Commit 82295dbf authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

enable black for mobile-vision

Summary:
https://fb.workplace.com/groups/pythonfoundation/posts/2990917737888352

Remove `mobile-vision` from opt-out list; leaving `mobile-vision/SNPE` opted out because of 3rd-party code.

arc lint --take BLACK --apply-patches --paths-cmd 'hg files mobile-vision'

allow-large-files

Reviewed By: sstsai-adl

Differential Revision: D30721093

fbshipit-source-id: 9e5c16d988b315b93a28038443ecfb92efd18ef8
parent a56c7e15
...@@ -5,11 +5,10 @@ ...@@ -5,11 +5,10 @@
Modules to compute the matching cost and solve the corresponding LSAP. Modules to compute the matching cost and solve the corresponding LSAP.
""" """
import torch import torch
from detr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
from torch import nn from torch import nn
from detr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
class HungarianMatcher(nn.Module): class HungarianMatcher(nn.Module):
"""This class computes an assignment between the targets and the predictions of the network """This class computes an assignment between the targets and the predictions of the network
...@@ -19,7 +18,13 @@ class HungarianMatcher(nn.Module): ...@@ -19,7 +18,13 @@ class HungarianMatcher(nn.Module):
while the others are un-matched (and thus treated as non-objects). while the others are un-matched (and thus treated as non-objects).
""" """
def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, use_focal_loss=False): def __init__(
self,
cost_class: float = 1,
cost_bbox: float = 1,
cost_giou: float = 1,
use_focal_loss=False,
):
"""Creates the matcher """Creates the matcher
Params: Params:
...@@ -31,12 +36,14 @@ class HungarianMatcher(nn.Module): ...@@ -31,12 +36,14 @@ class HungarianMatcher(nn.Module):
self.cost_class = cost_class self.cost_class = cost_class
self.cost_bbox = cost_bbox self.cost_bbox = cost_bbox
self.cost_giou = cost_giou self.cost_giou = cost_giou
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" assert (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
), "all costs cant be 0"
self.use_focal_loss = use_focal_loss self.use_focal_loss = use_focal_loss
@torch.no_grad() @torch.no_grad()
def forward(self, outputs, targets): def forward(self, outputs, targets):
""" Performs the matching """Performs the matching
Params: Params:
outputs: This is a dict that contains at least these entries: outputs: This is a dict that contains at least these entries:
...@@ -61,7 +68,9 @@ class HungarianMatcher(nn.Module): ...@@ -61,7 +68,9 @@ class HungarianMatcher(nn.Module):
if self.use_focal_loss: if self.use_focal_loss:
out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
else: else:
out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] out_prob = (
outputs["pred_logits"].flatten(0, 1).softmax(-1)
) # [batch_size * num_queries, num_classes]
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes # Also concat the target labels and boxes
...@@ -74,29 +83,57 @@ class HungarianMatcher(nn.Module): ...@@ -74,29 +83,57 @@ class HungarianMatcher(nn.Module):
if self.use_focal_loss: if self.use_focal_loss:
alpha = 0.25 alpha = 0.25
gamma = 2.0 gamma = 2.0
neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) neg_cost_class = (
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
)
pos_cost_class = (
alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
)
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
else: else:
cost_class = -out_prob[:, tgt_ids] # shape [batch_size * num_queries, \sum_b NUM-BOX_b] cost_class = -out_prob[
:, tgt_ids
] # shape [batch_size * num_queries, \sum_b NUM-BOX_b]
# Compute the L1 cost between boxes # Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) # shape [batch_size * num_queries,\sum_b NUM-BOX_b] cost_bbox = torch.cdist(
out_bbox, tgt_bbox, p=1
) # shape [batch_size * num_queries,\sum_b NUM-BOX_b]
# Compute the giou cost betwen boxes # Compute the giou cost betwen boxes
# shape [batch_size * num_queries, \sum_b NUM-BOX_b] # shape [batch_size * num_queries, \sum_b NUM-BOX_b]
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) cost_giou = -generalized_box_iou(
box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
)
# Final cost matrix # Final cost matrix
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou C = (
C = C.view(bs, num_queries, -1).cpu() # shape [batch_size, num_queries, \sum_b NUM-BOX_b] self.cost_bbox * cost_bbox
+ self.cost_class * cost_class
+ self.cost_giou * cost_giou
)
C = C.view(
bs, num_queries, -1
).cpu() # shape [batch_size, num_queries, \sum_b NUM-BOX_b]
sizes = [len(v["boxes"]) for v in targets] # shape [batch_size,] sizes = [len(v["boxes"]) for v in targets] # shape [batch_size,]
# each split c shape [batch_size, num_queries, NUM-BOX_b] # each split c shape [batch_size, num_queries, NUM-BOX_b]
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] indices = [
linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))
]
# A list where each item is [row_indices, col_indices] # A list where each item is [row_indices, col_indices]
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] return [
(
torch.as_tensor(i, dtype=torch.int64),
torch.as_tensor(j, dtype=torch.int64),
)
for i, j in indices
]
def build_matcher(args): def build_matcher(args):
return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou) return HungarianMatcher(
cost_class=args.set_cost_class,
cost_bbox=args.set_cost_bbox,
cost_giou=args.set_cost_giou,
)
...@@ -5,10 +5,10 @@ ...@@ -5,10 +5,10 @@
Various positional encodings for the transformer. Various positional encodings for the transformer.
""" """
import math import math
import torch
from torch import nn
import torch
from detr.util.misc import NestedTensor from detr.util.misc import NestedTensor
from torch import nn
class PositionEmbeddingSine(nn.Module): class PositionEmbeddingSine(nn.Module):
...@@ -16,7 +16,15 @@ class PositionEmbeddingSine(nn.Module): ...@@ -16,7 +16,15 @@ class PositionEmbeddingSine(nn.Module):
This is a more standard version of the position embedding, very similar to the one This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images. used by the Attention is all you need paper, generalized to work on images.
""" """
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None, centered=False):
def __init__(
self,
num_pos_feats=64,
temperature=10000,
normalize=False,
scale=None,
centered=False,
):
super().__init__() super().__init__()
self.num_pos_feats = num_pos_feats self.num_pos_feats = num_pos_feats
self.temperature = temperature self.temperature = temperature
...@@ -47,13 +55,25 @@ class PositionEmbeddingSine(nn.Module): ...@@ -47,13 +55,25 @@ class PositionEmbeddingSine(nn.Module):
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) # shape (N, ) dim_t = self.temperature ** (
2 * (dim_t // 2) / self.num_pos_feats
) # shape (N, )
pos_x = x_embed[:, :, :, None] / dim_t # shape (B, H, W, N) pos_x = x_embed[:, :, :, None] / dim_t # shape (B, H, W, N)
pos_y = y_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) # shape (B, H, W, N) pos_x = torch.stack(
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) # shape (B, H, W, N) (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # shape (B, 2*N, H, W) ).flatten(
3
) # shape (B, H, W, N)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(
3
) # shape (B, H, W, N)
pos = torch.cat((pos_y, pos_x), dim=3).permute(
0, 3, 1, 2
) # shape (B, 2*N, H, W)
return pos return pos
...@@ -61,6 +81,7 @@ class PositionEmbeddingLearned(nn.Module): ...@@ -61,6 +81,7 @@ class PositionEmbeddingLearned(nn.Module):
""" """
Absolute pos embedding, learned. Absolute pos embedding, learned.
""" """
def __init__(self, num_pos_feats=256): def __init__(self, num_pos_feats=256):
super().__init__() super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats) self.row_embed = nn.Embedding(50, num_pos_feats)
...@@ -78,19 +99,27 @@ class PositionEmbeddingLearned(nn.Module): ...@@ -78,19 +99,27 @@ class PositionEmbeddingLearned(nn.Module):
j = torch.arange(h, device=x.device) j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i) x_emb = self.col_embed(i)
y_emb = self.row_embed(j) y_emb = self.row_embed(j)
pos = torch.cat([ pos = (
x_emb.unsqueeze(0).repeat(h, 1, 1), torch.cat(
y_emb.unsqueeze(1).repeat(1, w, 1), [
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
)
.permute(2, 0, 1)
.unsqueeze(0)
.repeat(x.shape[0], 1, 1, 1)
)
return pos return pos
def build_position_encoding(args): def build_position_encoding(args):
N_steps = args.hidden_dim // 2 N_steps = args.hidden_dim // 2
if args.position_embedding in ('v2', 'sine'): if args.position_embedding in ("v2", "sine"):
# TODO find a better way of exposing other arguments # TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(N_steps, normalize=True) position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
elif args.position_embedding in ('v3', 'learned'): elif args.position_embedding in ("v3", "learned"):
position_embedding = PositionEmbeddingLearned(N_steps) position_embedding = PositionEmbeddingLearned(N_steps)
else: else:
raise ValueError(f"not supported {args.position_embedding}") raise ValueError(f"not supported {args.position_embedding}")
......
...@@ -8,14 +8,13 @@ import io ...@@ -8,14 +8,13 @@ import io
from collections import defaultdict from collections import defaultdict
from typing import List, Optional from typing import List, Optional
import detr.util.box_ops as box_ops
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor
from PIL import Image
import detr.util.box_ops as box_ops
from detr.util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list from detr.util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list
from PIL import Image
from torch import Tensor
try: try:
from panopticapi.utils import id2rgb, rgb2id from panopticapi.utils import id2rgb, rgb2id
...@@ -33,8 +32,12 @@ class DETRsegm(nn.Module): ...@@ -33,8 +32,12 @@ class DETRsegm(nn.Module):
p.requires_grad_(False) p.requires_grad_(False)
hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead
self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0) self.bbox_attention = MHAttentionMap(
self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim) hidden_dim, hidden_dim, nheads, dropout=0.0
)
self.mask_head = MaskHeadSmallConv(
hidden_dim + nheads, [1024, 512, 256], hidden_dim
)
def forward(self, samples: NestedTensor): def forward(self, samples: NestedTensor):
if isinstance(samples, (list, torch.Tensor)): if isinstance(samples, (list, torch.Tensor)):
...@@ -46,19 +49,27 @@ class DETRsegm(nn.Module): ...@@ -46,19 +49,27 @@ class DETRsegm(nn.Module):
src, mask = features[-1].decompose() src, mask = features[-1].decompose()
assert mask is not None assert mask is not None
src_proj = self.detr.input_proj(src) src_proj = self.detr.input_proj(src)
hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1]) hs, memory = self.detr.transformer(
src_proj, mask, self.detr.query_embed.weight, pos[-1]
)
outputs_class = self.detr.class_embed(hs) outputs_class = self.detr.class_embed(hs)
outputs_coord = self.detr.bbox_embed(hs).sigmoid() outputs_coord = self.detr.bbox_embed(hs).sigmoid()
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
if self.detr.aux_loss: if self.detr.aux_loss:
out['aux_outputs'] = self.detr._set_aux_loss(outputs_class, outputs_coord) out["aux_outputs"] = self.detr._set_aux_loss(outputs_class, outputs_coord)
# FIXME h_boxes takes the last one computed, keep this in mind # FIXME h_boxes takes the last one computed, keep this in mind
bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask) bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)
seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors]) seg_masks = self.mask_head(
outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]) src_proj,
bbox_mask,
[features[2].tensors, features[1].tensors, features[0].tensors],
)
outputs_seg_masks = seg_masks.view(
bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]
)
out["pred_masks"] = outputs_seg_masks out["pred_masks"] = outputs_seg_masks
return out return out
...@@ -77,7 +88,14 @@ class MaskHeadSmallConv(nn.Module): ...@@ -77,7 +88,14 @@ class MaskHeadSmallConv(nn.Module):
def __init__(self, dim, fpn_dims, context_dim): def __init__(self, dim, fpn_dims, context_dim):
super().__init__() super().__init__()
inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] inter_dims = [
dim,
context_dim // 2,
context_dim // 4,
context_dim // 8,
context_dim // 16,
context_dim // 64,
]
self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
self.gn1 = torch.nn.GroupNorm(8, dim) self.gn1 = torch.nn.GroupNorm(8, dim)
self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
...@@ -159,9 +177,19 @@ class MHAttentionMap(nn.Module): ...@@ -159,9 +177,19 @@ class MHAttentionMap(nn.Module):
def forward(self, q, k, mask: Optional[Tensor] = None): def forward(self, q, k, mask: Optional[Tensor] = None):
q = self.q_linear(q) q = self.q_linear(q)
k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) k = F.conv2d(
qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias
kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) )
qh = q.view(
q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads
)
kh = k.view(
k.shape[0],
self.num_heads,
self.hidden_dim // self.num_heads,
k.shape[-2],
k.shape[-1],
)
weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)
if mask is not None: if mask is not None:
...@@ -189,7 +217,9 @@ def dice_loss(inputs, targets, num_boxes): ...@@ -189,7 +217,9 @@ def dice_loss(inputs, targets, num_boxes):
return loss.sum() / num_boxes return loss.sum() / num_boxes
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): def sigmoid_focal_loss(
inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2
):
""" """
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args: Args:
...@@ -227,10 +257,14 @@ class PostProcessSegm(nn.Module): ...@@ -227,10 +257,14 @@ class PostProcessSegm(nn.Module):
assert len(orig_target_sizes) == len(max_target_sizes) assert len(orig_target_sizes) == len(max_target_sizes)
max_h, max_w = max_target_sizes.max(0)[0].tolist() max_h, max_w = max_target_sizes.max(0)[0].tolist()
outputs_masks = outputs["pred_masks"].squeeze(2) outputs_masks = outputs["pred_masks"].squeeze(2)
outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) outputs_masks = F.interpolate(
outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
)
outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()
for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): for i, (cur_mask, t, tt) in enumerate(
zip(outputs_masks, max_target_sizes, orig_target_sizes)
):
img_h, img_w = t[0], t[1] img_h, img_w = t[0], t[1]
results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
results[i]["masks"] = F.interpolate( results[i]["masks"] = F.interpolate(
...@@ -242,7 +276,7 @@ class PostProcessSegm(nn.Module): ...@@ -242,7 +276,7 @@ class PostProcessSegm(nn.Module):
class PostProcessPanoptic(nn.Module): class PostProcessPanoptic(nn.Module):
"""This class converts the output of the model to the final panoptic result, in the format expected by the """This class converts the output of the model to the final panoptic result, in the format expected by the
coco panoptic API """ coco panoptic API"""
def __init__(self, is_thing_map, threshold=0.85): def __init__(self, is_thing_map, threshold=0.85):
""" """
...@@ -255,19 +289,23 @@ class PostProcessPanoptic(nn.Module): ...@@ -255,19 +289,23 @@ class PostProcessPanoptic(nn.Module):
self.threshold = threshold self.threshold = threshold
self.is_thing_map = is_thing_map self.is_thing_map = is_thing_map
def forward(self, outputs, processed_sizes, target_sizes=None): #noqa: C901 def forward(self, outputs, processed_sizes, target_sizes=None): # noqa: C901
""" This function computes the panoptic prediction from the model's predictions. """This function computes the panoptic prediction from the model's predictions.
Parameters: Parameters:
outputs: This is a dict coming directly from the model. See the model doc for the content. outputs: This is a dict coming directly from the model. See the model doc for the content.
processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the
model, ie the size after data augmentation but before batching. model, ie the size after data augmentation but before batching.
target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
of each prediction. If left to None, it will default to the processed_sizes of each prediction. If left to None, it will default to the processed_sizes
""" """
if target_sizes is None: if target_sizes is None:
target_sizes = processed_sizes target_sizes = processed_sizes
assert len(processed_sizes) == len(target_sizes) assert len(processed_sizes) == len(target_sizes)
out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"] out_logits, raw_masks, raw_boxes = (
outputs["pred_logits"],
outputs["pred_masks"],
outputs["pred_boxes"],
)
assert len(out_logits) == len(raw_masks) == len(target_sizes) assert len(out_logits) == len(raw_masks) == len(target_sizes)
preds = [] preds = []
...@@ -281,12 +319,16 @@ class PostProcessPanoptic(nn.Module): ...@@ -281,12 +319,16 @@ class PostProcessPanoptic(nn.Module):
): ):
# we filter empty queries and detection below threshold # we filter empty queries and detection below threshold
scores, labels = cur_logits.softmax(-1).max(-1) scores, labels = cur_logits.softmax(-1).max(-1)
keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold) keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (
scores > self.threshold
)
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1) cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
cur_scores = cur_scores[keep] cur_scores = cur_scores[keep]
cur_classes = cur_classes[keep] cur_classes = cur_classes[keep]
cur_masks = cur_masks[keep] cur_masks = cur_masks[keep]
cur_masks = interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) cur_masks = interpolate(
cur_masks[:, None], to_tuple(size), mode="bilinear"
).squeeze(1)
cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep]) cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])
h, w = cur_masks.shape[-2:] h, w = cur_masks.shape[-2:]
...@@ -322,10 +364,14 @@ class PostProcessPanoptic(nn.Module): ...@@ -322,10 +364,14 @@ class PostProcessPanoptic(nn.Module):
final_h, final_w = to_tuple(target_size) final_h, final_w = to_tuple(target_size)
seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy())) seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST) seg_img = seg_img.resize(
size=(final_w, final_h), resample=Image.NEAREST
)
np_seg_img = ( np_seg_img = (
torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy() torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
.view(final_h, final_w, 3)
.numpy()
) )
m_id = torch.from_numpy(rgb2id(np_seg_img)) m_id = torch.from_numpy(rgb2id(np_seg_img))
...@@ -339,7 +385,9 @@ class PostProcessPanoptic(nn.Module): ...@@ -339,7 +385,9 @@ class PostProcessPanoptic(nn.Module):
# We know filter empty masks as long as we find some # We know filter empty masks as long as we find some
while True: while True:
filtered_small = torch.as_tensor( filtered_small = torch.as_tensor(
[area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device [area[i] <= 4 for i, c in enumerate(cur_classes)],
dtype=torch.bool,
device=keep.device,
) )
if filtered_small.any().item(): if filtered_small.any().item():
cur_scores = cur_scores[~filtered_small] cur_scores = cur_scores[~filtered_small]
...@@ -355,11 +403,21 @@ class PostProcessPanoptic(nn.Module): ...@@ -355,11 +403,21 @@ class PostProcessPanoptic(nn.Module):
segments_info = [] segments_info = []
for i, a in enumerate(area): for i, a in enumerate(area):
cat = cur_classes[i].item() cat = cur_classes[i].item()
segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a}) segments_info.append(
{
"id": i,
"isthing": self.is_thing_map[cat],
"category_id": cat,
"area": a,
}
)
del cur_classes del cur_classes
with io.BytesIO() as out: with io.BytesIO() as out:
seg_img.save(out, format="PNG") seg_img.save(out, format="PNG")
predictions = {"png_string": out.getvalue(), "segments_info": segments_info} predictions = {
"png_string": out.getvalue(),
"segments_info": segments_info,
}
preds.append(predictions) preds.append(predictions)
return preds return preds
...@@ -18,23 +18,38 @@ from torch import nn, Tensor ...@@ -18,23 +18,38 @@ from torch import nn, Tensor
class Transformer(nn.Module): class Transformer(nn.Module):
def __init__(
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, self,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, d_model=512,
activation="relu", normalize_before=False, nhead=8,
return_intermediate_dec=False): num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
):
super().__init__() super().__init__()
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, encoder_layer = TransformerEncoderLayer(
dropout, activation, normalize_before) d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) self.encoder = TransformerEncoder(
encoder_layer, num_encoder_layers, encoder_norm
)
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, decoder_layer = TransformerDecoderLayer(
dropout, activation, normalize_before) d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
decoder_norm = nn.LayerNorm(d_model) decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, self.decoder = TransformerDecoder(
return_intermediate=return_intermediate_dec) decoder_layer,
num_decoder_layers,
decoder_norm,
return_intermediate=return_intermediate_dec,
)
self._reset_parameters() self._reset_parameters()
...@@ -63,30 +78,41 @@ class Transformer(nn.Module): ...@@ -63,30 +78,41 @@ class Transformer(nn.Module):
# memory shape (L, B, C) # memory shape (L, B, C)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
# hs shape (NUM_LEVEL, S, B, C) # hs shape (NUM_LEVEL, S, B, C)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, hs = self.decoder(
pos=pos_embed, query_pos=query_embed) tgt,
memory,
memory_key_padding_mask=mask,
pos=pos_embed,
query_pos=query_embed,
)
# return shape (NUM_LEVEL, B, S, C) and (B, C, H, W) # return shape (NUM_LEVEL, B, S, C) and (B, C, H, W)
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
class TransformerEncoder(nn.Module): class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None): def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__() super().__init__()
self.layers = _get_clones(encoder_layer, num_layers) self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers self.num_layers = num_layers
self.norm = norm self.norm = norm
def forward(self, src, def forward(
mask: Optional[Tensor] = None, self,
src_key_padding_mask: Optional[Tensor] = None, src,
pos: Optional[Tensor] = None): mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
output = src output = src
# mask, shape (L, L) # mask, shape (L, L)
# src_key_padding_mask, shape (B, L) # src_key_padding_mask, shape (B, L)
for layer in self.layers: for layer in self.layers:
output = layer(output, src_mask=mask, output = layer(
src_key_padding_mask=src_key_padding_mask, pos=pos) output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
pos=pos,
)
if self.norm is not None: if self.norm is not None:
output = self.norm(output) output = self.norm(output)
...@@ -95,7 +121,6 @@ class TransformerEncoder(nn.Module): ...@@ -95,7 +121,6 @@ class TransformerEncoder(nn.Module):
class TransformerDecoder(nn.Module): class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__() super().__init__()
self.layers = _get_clones(decoder_layer, num_layers) self.layers = _get_clones(decoder_layer, num_layers)
...@@ -103,13 +128,17 @@ class TransformerDecoder(nn.Module): ...@@ -103,13 +128,17 @@ class TransformerDecoder(nn.Module):
self.norm = norm self.norm = norm
self.return_intermediate = return_intermediate self.return_intermediate = return_intermediate
def forward(self, tgt, memory, def forward(
tgt_mask: Optional[Tensor] = None, self,
memory_mask: Optional[Tensor] = None, tgt,
tgt_key_padding_mask: Optional[Tensor] = None, memory,
memory_key_padding_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None): tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output = tgt output = tgt
intermediate = [] intermediate = []
...@@ -119,11 +148,16 @@ class TransformerDecoder(nn.Module): ...@@ -119,11 +148,16 @@ class TransformerDecoder(nn.Module):
# memory_mask shape (L, S) # memory_mask shape (L, S)
# memory_key_padding_mask shape (B, S) # memory_key_padding_mask shape (B, S)
for layer in self.layers: for layer in self.layers:
output = layer(output, memory, tgt_mask=tgt_mask, output = layer(
memory_mask=memory_mask, output,
tgt_key_padding_mask=tgt_key_padding_mask, memory,
memory_key_padding_mask=memory_key_padding_mask, tgt_mask=tgt_mask,
pos=pos, query_pos=query_pos) memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos,
query_pos=query_pos,
)
if self.return_intermediate: if self.return_intermediate:
intermediate.append(self.norm(output)) intermediate.append(self.norm(output))
...@@ -140,9 +174,15 @@ class TransformerDecoder(nn.Module): ...@@ -140,9 +174,15 @@ class TransformerDecoder(nn.Module):
class TransformerEncoderLayer(nn.Module): class TransformerEncoderLayer(nn.Module):
def __init__(
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, self,
activation="relu", normalize_before=False): d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__() super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model # Implementation of Feedforward model
...@@ -161,16 +201,19 @@ class TransformerEncoderLayer(nn.Module): ...@@ -161,16 +201,19 @@ class TransformerEncoderLayer(nn.Module):
def with_pos_embed(self, tensor, pos: Optional[Tensor]): def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos return tensor if pos is None else tensor + pos
def forward_post(self, def forward_post(
src, self,
src_mask: Optional[Tensor] = None, src,
src_key_padding_mask: Optional[Tensor] = None, src_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None): src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(src, pos) # shape (L, B, D) q = k = self.with_pos_embed(src, pos) # shape (L, B, D)
# src mask, shape (L, L) # src mask, shape (L, L)
# src_key_padding_mask: shape (B, L) # src_key_padding_mask: shape (B, L)
src2 = self.self_attn(q, k, src, attn_mask=src_mask, src2 = self.self_attn(
key_padding_mask=src_key_padding_mask)[0] q, k, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
)[0]
src = src + self.dropout1(src2) src = src + self.dropout1(src2)
src = self.norm1(src) src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
...@@ -178,33 +221,46 @@ class TransformerEncoderLayer(nn.Module): ...@@ -178,33 +221,46 @@ class TransformerEncoderLayer(nn.Module):
src = self.norm2(src) src = self.norm2(src)
return src return src
def forward_pre(self, src, def forward_pre(
src_mask: Optional[Tensor] = None, self,
src_key_padding_mask: Optional[Tensor] = None, src,
pos: Optional[Tensor] = None): src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
src2 = self.norm1(src) src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos) q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, src2, attn_mask=src_mask, src2 = self.self_attn(
key_padding_mask=src_key_padding_mask)[0] q, k, src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
)[0]
src = src + self.dropout1(src2) src = src + self.dropout1(src2)
src2 = self.norm2(src) src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2) src = src + self.dropout2(src2)
return src return src
def forward(self, src, def forward(
src_mask: Optional[Tensor] = None, self,
src_key_padding_mask: Optional[Tensor] = None, src,
pos: Optional[Tensor] = None): src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
if self.normalize_before: if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos) return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos) return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module): class TransformerDecoderLayer(nn.Module):
def __init__(
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, self,
activation="relu", normalize_before=False): d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__() super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
...@@ -226,28 +282,36 @@ class TransformerDecoderLayer(nn.Module): ...@@ -226,28 +282,36 @@ class TransformerDecoderLayer(nn.Module):
def with_pos_embed(self, tensor, pos: Optional[Tensor]): def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos return tensor if pos is None else tensor + pos
def forward_post(self, tgt, memory, def forward_post(
tgt_mask: Optional[Tensor] = None, self,
memory_mask: Optional[Tensor] = None, tgt,
tgt_key_padding_mask: Optional[Tensor] = None, memory,
memory_key_padding_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None): tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
# tgt shape (L, B, C) # tgt shape (L, B, C)
# tgt_mask shape (L, L) # tgt_mask shape (L, L)
# tgt_key_padding_mask shape (B, L) # tgt_key_padding_mask shape (B, L)
q = k = self.with_pos_embed(tgt, query_pos) q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, tgt, attn_mask=tgt_mask, tgt2 = self.self_attn(
key_padding_mask=tgt_key_padding_mask)[0] q, k, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2) tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt) tgt = self.norm1(tgt)
# memory_mask shape (L, S) # memory_mask shape (L, S)
# memory_key_padding_mask shape (B, S) # memory_key_padding_mask shape (B, S)
# query_pos shape (L, B, C) # query_pos shape (L, B, C)
tgt2 = self.multihead_attn(self.with_pos_embed(tgt, query_pos), tgt2 = self.multihead_attn(
self.with_pos_embed(memory, pos), self.with_pos_embed(tgt, query_pos),
memory, attn_mask=memory_mask, self.with_pos_embed(memory, pos),
key_padding_mask=memory_key_padding_mask)[0] memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2) tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt) tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
...@@ -256,41 +320,69 @@ class TransformerDecoderLayer(nn.Module): ...@@ -256,41 +320,69 @@ class TransformerDecoderLayer(nn.Module):
# return tgt shape (L, B, C) # return tgt shape (L, B, C)
return tgt return tgt
def forward_pre(self, tgt, memory, def forward_pre(
tgt_mask: Optional[Tensor] = None, self,
memory_mask: Optional[Tensor] = None, tgt,
tgt_key_padding_mask: Optional[Tensor] = None, memory,
memory_key_padding_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None): tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
tgt2 = self.norm1(tgt) tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos) q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, tgt2, attn_mask=tgt_mask, tgt2 = self.self_attn(
key_padding_mask=tgt_key_padding_mask)[0] q, k, tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2) tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt) tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(self.with_pos_embed(tgt2, query_pos), tgt2 = self.multihead_attn(
self.with_pos_embed(memory, pos), self.with_pos_embed(tgt2, query_pos),
memory, attn_mask=memory_mask, self.with_pos_embed(memory, pos),
key_padding_mask=memory_key_padding_mask)[0] memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2) tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt) tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2) tgt = tgt + self.dropout3(tgt2)
return tgt return tgt
def forward(self, tgt, memory, def forward(
tgt_mask: Optional[Tensor] = None, self,
memory_mask: Optional[Tensor] = None, tgt,
tgt_key_padding_mask: Optional[Tensor] = None, memory,
memory_key_padding_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None): tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
if self.normalize_before: if self.normalize_before:
return self.forward_pre(tgt, memory, tgt_mask, memory_mask, return self.forward_pre(
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) tgt,
return self.forward_post(tgt, memory, tgt_mask, memory_mask, memory,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
return self.forward_post(
tgt,
memory,
tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
def _get_clones(module, N): def _get_clones(module, N):
...@@ -318,4 +410,4 @@ def _get_activation_fn(activation): ...@@ -318,4 +410,4 @@ def _get_activation_fn(activation):
return F.gelu return F.gelu
if activation == "glu": if activation == "glu":
return F.glu return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.") raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
...@@ -9,15 +9,15 @@ ...@@ -9,15 +9,15 @@
# ------------------------------------------------------------------------------------------------ # ------------------------------------------------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import print_function
from __future__ import division from __future__ import division
from __future__ import print_function
import warnings
import math import math
import warnings
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from torch.nn.init import xavier_uniform_, constant_ from torch.nn.init import xavier_uniform_, constant_
from ..functions import MSDeformAttnFunction from ..functions import MSDeformAttnFunction
...@@ -25,8 +25,10 @@ from ..functions import MSDeformAttnFunction ...@@ -25,8 +25,10 @@ from ..functions import MSDeformAttnFunction
def _is_power_of_2(n): def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0): if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) raise ValueError(
return (n & (n-1) == 0) and n != 0 "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
)
return (n & (n - 1) == 0) and n != 0
class MSDeformAttn(nn.Module): class MSDeformAttn(nn.Module):
...@@ -40,12 +42,18 @@ class MSDeformAttn(nn.Module): ...@@ -40,12 +42,18 @@ class MSDeformAttn(nn.Module):
""" """
super().__init__() super().__init__()
if d_model % n_heads != 0: if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) raise ValueError(
"d_model must be divisible by n_heads, but got {} and {}".format(
d_model, n_heads
)
)
_d_per_head = d_model // n_heads _d_per_head = d_model // n_heads
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_head): if not _is_power_of_2(_d_per_head):
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " warnings.warn(
"which is more efficient in our CUDA implementation.") "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation."
)
self.im2col_step = 64 self.im2col_step = 64
...@@ -62,25 +70,39 @@ class MSDeformAttn(nn.Module): ...@@ -62,25 +70,39 @@ class MSDeformAttn(nn.Module):
self._reset_parameters() self._reset_parameters()
def _reset_parameters(self): def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.) constant_(self.sampling_offsets.weight.data, 0.0)
# shape (num_heads,) # shape (num_heads,)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) thetas = torch.arange(self.n_heads, dtype=torch.float32) * (
2.0 * math.pi / self.n_heads
)
# shape (num_heads, 2) # shape (num_heads, 2)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
# shape (num_heads, num_levels, num_points, 2) # shape (num_heads, num_levels, num_points, 2)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(self.n_heads, 1, 1, 2)
.repeat(1, self.n_levels, self.n_points, 1)
)
for i in range(self.n_points): for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1 grid_init[:, :, i, :] *= i + 1
with torch.no_grad(): with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.) constant_(self.attention_weights.weight.data, 0.0)
constant_(self.attention_weights.bias.data, 0.) constant_(self.attention_weights.bias.data, 0.0)
xavier_uniform_(self.value_proj.weight.data) xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.) constant_(self.value_proj.bias.data, 0.0)
xavier_uniform_(self.output_proj.weight.data) xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.) constant_(self.output_proj.bias.data, 0.0)
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): def forward(
self,
query,
reference_points,
input_flatten,
input_spatial_shapes,
input_level_start_index,
input_padding_mask=None,
):
""" """
:param query (N, Length_{query}, C) :param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
...@@ -100,21 +122,45 @@ class MSDeformAttn(nn.Module): ...@@ -100,21 +122,45 @@ class MSDeformAttn(nn.Module):
if input_padding_mask is not None: if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0)) value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) sampling_offsets = self.sampling_offsets(query).view(
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) )
attention_weights = self.attention_weights(query).view(
N, Len_q, self.n_heads, self.n_levels * self.n_points
)
attention_weights = F.softmax(attention_weights, -1).view(
N, Len_q, self.n_heads, self.n_levels, self.n_points
)
# N, Len_q, n_heads, n_levels, n_points, 2 # N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2: if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) offset_normalizer = torch.stack(
sampling_locations = reference_points[:, :, None, :, None, :] \ [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :] )
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4: elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \ sampling_locations = (
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 reference_points[:, :, None, :, None, :2]
+ sampling_offsets
/ self.n_points
* reference_points[:, :, None, :, None, 2:]
* 0.5
)
else: else:
raise ValueError( raise ValueError(
'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
reference_points.shape[-1]
)
)
output = MSDeformAttnFunction.apply( output = MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) value,
input_spatial_shapes,
input_level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
output = self.output_proj(output) output = self.output_proj(output)
return output return output
...@@ -4,9 +4,9 @@ from d2go.config import CfgNode as CN ...@@ -4,9 +4,9 @@ from d2go.config import CfgNode as CN
from d2go.data.dataset_mappers.build import D2GO_DATA_MAPPER_REGISTRY from d2go.data.dataset_mappers.build import D2GO_DATA_MAPPER_REGISTRY
from d2go.data.dataset_mappers.d2go_dataset_mapper import D2GoDatasetMapper from d2go.data.dataset_mappers.d2go_dataset_mapper import D2GoDatasetMapper
from d2go.runner import GeneralizedRCNNRunner from d2go.runner import GeneralizedRCNNRunner
from detr.d2 import DetrDatasetMapper, add_detr_config
from detr.backbone.deit import add_deit_backbone_config from detr.backbone.deit import add_deit_backbone_config
from detr.backbone.pit import add_pit_backbone_config from detr.backbone.pit import add_pit_backbone_config
from detr.d2 import DetrDatasetMapper, add_detr_config
@D2GO_DATA_MAPPER_REGISTRY.register() @D2GO_DATA_MAPPER_REGISTRY.register()
......
...@@ -10,15 +10,13 @@ from torchvision.ops.boxes import box_area ...@@ -10,15 +10,13 @@ from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x): def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1) x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1) return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x): def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1) x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2, b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
(x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1) return torch.stack(b, dim=-1)
...@@ -79,11 +77,11 @@ def masks_to_boxes(masks): ...@@ -79,11 +77,11 @@ def masks_to_boxes(masks):
x = torch.arange(0, w, dtype=torch.float) x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x) y, x = torch.meshgrid(y, x)
x_mask = (masks * x.unsqueeze(0)) x_mask = masks * x.unsqueeze(0)
x_max = x_mask.flatten(1).max(-1)[0] x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
y_mask = (masks * y.unsqueeze(0)) y_mask = masks * y.unsqueeze(0)
y_max = y_mask.flatten(1).max(-1)[0] y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
......
This diff is collapsed.
...@@ -3,17 +3,22 @@ ...@@ -3,17 +3,22 @@
""" """
Plotting utilities to visualize training logs. Plotting utilities to visualize training logs.
""" """
import torch from pathlib import Path, PurePath
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd
import seaborn as sns import seaborn as sns
import matplotlib.pyplot as plt import torch
from pathlib import Path, PurePath
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): def plot_logs(
''' logs,
fields=("class_error", "loss_bbox_unscaled", "mAP"),
ewm_col=0,
log_name="log.txt",
):
"""
Function to plot specific fields from training log(s). Plots both training and test results. Function to plot specific fields from training log(s). Plots both training and test results.
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
...@@ -24,7 +29,7 @@ def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col ...@@ -24,7 +29,7 @@ def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col
:: Outputs - matplotlib plots of results in fields, color coded for each log file. :: Outputs - matplotlib plots of results in fields, color coded for each log file.
- solid lines are training results, dashed lines are test results. - solid lines are training results, dashed lines are test results.
''' """
func_name = "plot_utils.py::plot_logs" func_name = "plot_utils.py::plot_logs"
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
...@@ -33,17 +38,25 @@ def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col ...@@ -33,17 +38,25 @@ def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col
if not isinstance(logs, list): if not isinstance(logs, list):
if isinstance(logs, PurePath): if isinstance(logs, PurePath):
logs = [logs] logs = [logs]
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") print(
f"{func_name} info: logs param expects a list argument, converted to list[Path]."
)
else: else:
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ raise ValueError(
Expect list[Path] or single Path obj, received {type(logs)}") f"{func_name} - invalid argument for logs parameter.\n \
Expect list[Path] or single Path obj, received {type(logs)}"
)
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
for _, dir in enumerate(logs): for _, dir in enumerate(logs):
if not isinstance(dir, PurePath): if not isinstance(dir, PurePath):
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") raise ValueError(
f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}"
)
if not dir.exists(): if not dir.exists():
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") raise ValueError(
f"{func_name} - invalid directory in logs argument:\n{dir}"
)
# verify log_name exists # verify log_name exists
fn = Path(dir / log_name) fn = Path(dir / log_name)
if not fn.exists(): if not fn.exists():
...@@ -58,52 +71,57 @@ def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col ...@@ -58,52 +71,57 @@ def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
for j, field in enumerate(fields): for j, field in enumerate(fields):
if field == 'mAP': if field == "mAP":
coco_eval = pd.DataFrame( coco_eval = (
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] pd.DataFrame(np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1])
).ewm(com=ewm_col).mean() .ewm(com=ewm_col)
.mean()
)
axs[j].plot(coco_eval, c=color) axs[j].plot(coco_eval, c=color)
else: else:
df.interpolate().ewm(com=ewm_col).mean().plot( df.interpolate().ewm(com=ewm_col).mean().plot(
y=[f'train_{field}', f'test_{field}'], y=[f"train_{field}", f"test_{field}"],
ax=axs[j], ax=axs[j],
color=[color] * 2, color=[color] * 2,
style=['-', '--'] style=["-", "--"],
) )
for ax, field in zip(axs, fields): for ax, field in zip(axs, fields):
ax.legend([Path(p).name for p in logs]) ax.legend([Path(p).name for p in logs])
ax.set_title(field) ax.set_title(field)
def plot_precision_recall(files, naming_scheme='iter'): def plot_precision_recall(files, naming_scheme="iter"):
if naming_scheme == 'exp_id': if naming_scheme == "exp_id":
# name becomes exp_id # name becomes exp_id
names = [f.parts[-3] for f in files] names = [f.parts[-3] for f in files]
elif naming_scheme == 'iter': elif naming_scheme == "iter":
names = [f.stem for f in files] names = [f.stem for f in files]
else: else:
raise ValueError(f'not supported {naming_scheme}') raise ValueError(f"not supported {naming_scheme}")
fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): for f, color, name in zip(
files, sns.color_palette("Blues", n_colors=len(files)), names
):
data = torch.load(f) data = torch.load(f)
# precision is n_iou, n_points, n_cat, n_area, max_det # precision is n_iou, n_points, n_cat, n_area, max_det
precision = data['precision'] precision = data["precision"]
recall = data['params'].recThrs recall = data["params"].recThrs
scores = data['scores'] scores = data["scores"]
# take precision for all classes, all areas and 100 detections # take precision for all classes, all areas and 100 detections
precision = precision[0, :, :, 0, -1].mean(1) precision = precision[0, :, :, 0, -1].mean(1)
scores = scores[0, :, :, 0, -1].mean(1) scores = scores[0, :, :, 0, -1].mean(1)
prec = precision.mean() prec = precision.mean()
rec = data['recall'][0, :, 0, -1].mean() rec = data["recall"][0, :, 0, -1].mean()
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + print(
f'score={scores.mean():0.3f}, ' + f"{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, "
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' + f"score={scores.mean():0.3f}, "
) + f"f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}"
)
axs[0].plot(recall, precision, c=color) axs[0].plot(recall, precision, c=color)
axs[1].plot(recall, scores, c=color) axs[1].plot(recall, scores, c=color)
axs[0].set_title('Precision / Recall') axs[0].set_title("Precision / Recall")
axs[0].legend(names) axs[0].legend(names)
axs[1].set_title('Scores / Recall') axs[1].set_title("Scores / Recall")
axs[1].legend(names) axs[1].legend(names)
return fig, axs return fig, axs
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import logging
import unittest import unittest
from detr.backbone.deit import add_deit_backbone_config
from detr.backbone.pit import add_pit_backbone_config
import torch import torch
from detectron2.utils.file_io import PathManager
from detectron2.checkpoint import DetectionCheckpointer
from d2go.config import CfgNode as CN from d2go.config import CfgNode as CN
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling import BACKBONE_REGISTRY from detectron2.modeling import BACKBONE_REGISTRY
from detectron2.utils.file_io import PathManager
from detr.backbone.deit import add_deit_backbone_config
from detr.backbone.pit import add_pit_backbone_config
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# avoid testing on sandcastle due to access to manifold # avoid testing on sandcastle due to access to manifold
USE_CUDA = torch.cuda.device_count() > 0 USE_CUDA = torch.cuda.device_count() > 0
class TestTransformerBackbone(unittest.TestCase): class TestTransformerBackbone(unittest.TestCase):
@unittest.skipIf(not USE_CUDA,"avoid testing on sandcastle due to access to manifold") @unittest.skipIf(
not USE_CUDA, "avoid testing on sandcastle due to access to manifold"
)
def test_deit_model(self): def test_deit_model(self):
cfg = CN() cfg = CN()
cfg.MODEL = CN() cfg.MODEL = CN()
...@@ -49,9 +52,10 @@ class TestTransformerBackbone(unittest.TestCase): ...@@ -49,9 +52,10 @@ class TestTransformerBackbone(unittest.TestCase):
x = torch.rand(1, 3, input_size_h, input_size_w) x = torch.rand(1, 3, input_size_h, input_size_w)
y = model(x) y = model(x)
print(f"x.shape: {x.shape}, y.shape: {y.shape}") print(f"x.shape: {x.shape}, y.shape: {y.shape}")
@unittest.skipIf(not USE_CUDA,"avoid testing on sandcastle due to access to manifold") @unittest.skipIf(
not USE_CUDA, "avoid testing on sandcastle due to access to manifold"
)
def test_pit_model(self): def test_pit_model(self):
cfg = CN() cfg = CN()
cfg.MODEL = CN() cfg.MODEL = CN()
......
...@@ -13,6 +13,7 @@ from d2go.utils.testing.data_loader_helper import create_local_dataset ...@@ -13,6 +13,7 @@ from d2go.utils.testing.data_loader_helper import create_local_dataset
# RUN: # RUN:
# buck test mobile-vision/d2go/projects_oss/detr:test_detr_runner # buck test mobile-vision/d2go/projects_oss/detr:test_detr_runner
def _get_cfg(runner, output_dir, dataset_name): def _get_cfg(runner, output_dir, dataset_name):
cfg = runner.get_default_cfg() cfg = runner.get_default_cfg()
cfg.MODEL.DEVICE = "cpu" cfg.MODEL.DEVICE = "cpu"
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment