Unverified Commit a6d39f6a authored by Yuliang Liu, committed by GitHub

Merge pull request #39 from Yuliang-Liu/dev

Data generation
parents c7341cda 2189c3c4
import torch
from fvcore.nn import giou_loss, smooth_l1_loss
from torch import nn
from torch.nn import functional as F
import fvcore.nn.weight_init as weight_init
from detectron2.config import configurable
from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
from detectron2.modeling.roi_heads.fast_rcnn import _log_classification_stats
__all__ = ["GRiTFastRCNNOutputLayers"]
class GRiTFastRCNNOutputLayers(FastRCNNOutputLayers):
@configurable
def __init__(
self,
input_shape: ShapeSpec,
**kwargs,
):
super().__init__(
input_shape=input_shape,
**kwargs,
)
input_size = input_shape.channels * \
(input_shape.width or 1) * (input_shape.height or 1)
self.bbox_pred = nn.Sequential(
nn.Linear(input_size, input_size),
nn.ReLU(inplace=True),
nn.Linear(input_size, 4)
)
weight_init.c2_xavier_fill(self.bbox_pred[0])
nn.init.normal_(self.bbox_pred[-1].weight, std=0.001)
nn.init.constant_(self.bbox_pred[-1].bias, 0)
@classmethod
def from_config(cls, cfg, input_shape):
ret = super().from_config(cfg, input_shape)
return ret
def losses(self, predictions, proposals):
scores, proposal_deltas = predictions
gt_classes = (
cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0)
)
num_classes = self.num_classes
_log_classification_stats(scores, gt_classes)
if len(proposals):
proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0) # Nx4
assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
gt_boxes = cat(
[(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals],
dim=0,
)
else:
proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)
loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes)
return {
"loss_cls": loss_cls,
"loss_box_reg": self.box_reg_loss(
proposal_boxes, gt_boxes, proposal_deltas, gt_classes,
num_classes=num_classes)
}
def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes):
if pred_class_logits.numel() == 0:
return pred_class_logits.new_zeros([1])[0]
loss = F.cross_entropy(
pred_class_logits, gt_classes, reduction="mean")
return loss
def box_reg_loss(
self, proposal_boxes, gt_boxes, pred_deltas, gt_classes,
num_classes=-1):
num_classes = num_classes if num_classes > 0 else self.num_classes
box_dim = proposal_boxes.shape[1]
fg_inds = nonzero_tuple((gt_classes >= 0) & (gt_classes < num_classes))[0]
if pred_deltas.shape[1] == box_dim:
fg_pred_deltas = pred_deltas[fg_inds]
else:
fg_pred_deltas = pred_deltas.view(-1, self.num_classes, box_dim)[
fg_inds, gt_classes[fg_inds]
]
if self.box_reg_loss_type == "smooth_l1":
gt_pred_deltas = self.box2box_transform.get_deltas(
proposal_boxes[fg_inds],
gt_boxes[fg_inds],
)
loss_box_reg = smooth_l1_loss(
fg_pred_deltas, gt_pred_deltas, self.smooth_l1_beta, reduction="sum"
)
elif self.box_reg_loss_type == "giou":
fg_pred_boxes = self.box2box_transform.apply_deltas(
fg_pred_deltas, proposal_boxes[fg_inds]
)
loss_box_reg = giou_loss(fg_pred_boxes, gt_boxes[fg_inds], reduction="sum")
else:
raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")
return loss_box_reg / max(gt_classes.numel(), 1.0)
def predict_probs(self, predictions, proposals):
scores = predictions[0]
num_inst_per_image = [len(p) for p in proposals]
probs = F.softmax(scores, dim=-1)
return probs.split(num_inst_per_image, dim=0)
def forward(self, x):
if x.dim() > 2:
x = torch.flatten(x, start_dim=1)
scores = []
cls_scores = self.cls_score(x)
scores.append(cls_scores)
scores = torch.cat(scores, dim=1)
proposal_deltas = self.bbox_pred(x)
return scores, proposal_deltas
import math
import torch
from typing import Dict, List, Optional, Tuple, Union
from detectron2.config import configurable
from detectron2.structures import Boxes, Instances, pairwise_iou
from detectron2.utils.events import get_event_storage
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient
from detectron2.modeling.poolers import ROIPooler
from detectron2.layers import batched_nms
from .grit_fast_rcnn import GRiTFastRCNNOutputLayers
from ..text.text_decoder import TransformerDecoderTextualHead, GRiTTextDecoder, AutoRegressiveBeamSearch
from ..text.load_text_token import LoadTextTokens
from transformers import BertTokenizer
from grit.data.custom_dataset_mapper import ObjDescription
from ..soft_nms import batched_soft_nms
import logging
logger = logging.getLogger(__name__)
@ROI_HEADS_REGISTRY.register()
class GRiTROIHeadsAndTextDecoder(CascadeROIHeads):
@configurable
def __init__(
self,
*,
text_decoder_transformer,
train_task: list,
test_task: str,
mult_proposal_score: bool = False,
mask_weight: float = 1.0,
object_feat_pooler=None,
soft_nms_enabled=False,
beam_size=1,
**kwargs,
):
super().__init__(**kwargs)
self.mult_proposal_score = mult_proposal_score
self.mask_weight = mask_weight
self.object_feat_pooler = object_feat_pooler
self.soft_nms_enabled = soft_nms_enabled
self.test_task = test_task
self.beam_size = beam_size
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
self.tokenizer = tokenizer
assert test_task in train_task, 'GRiT has not been trained on {} task, ' \
'please verify the task name or train a new ' \
'GRiT on {} task'.format(test_task, test_task)
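# Begin-token scheme: the first task in TRAIN_TASK starts its captions with [CLS]
# (id 101 in bert-base-uncased); each additional task i is assigned id 103 + i,
# which falls in BERT's unused-token range (104, 105, ...). For example (illustrative),
# with train_task = ['ObjectDet', 'DenseCap'], ObjectDet -> 101 and DenseCap -> 104.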
task_begin_tokens = {}
for i, task in enumerate(train_task):
if i == 0:
task_begin_tokens[task] = tokenizer.cls_token_id
else:
task_begin_tokens[task] = 103 + i
self.task_begin_tokens = task_begin_tokens
beamsearch_decode = AutoRegressiveBeamSearch(
end_token_id=tokenizer.sep_token_id,
max_steps=40,
beam_size=beam_size,
objectdet=test_task == "ObjectDet",
per_node_beam_size=1,
)
self.text_decoder = GRiTTextDecoder(
text_decoder_transformer,
beamsearch_decode=beamsearch_decode,
begin_token_id=task_begin_tokens[test_task],
loss_type='smooth',
tokenizer=tokenizer,
)
self.get_target_text_tokens = LoadTextTokens(tokenizer, max_text_len=40, padding='do_not_pad')
@classmethod
def from_config(cls, cfg, input_shape):
ret = super().from_config(cfg, input_shape)
text_decoder_transformer = TransformerDecoderTextualHead(
object_feature_size=cfg.MODEL.FPN.OUT_CHANNELS,
vocab_size=cfg.TEXT_DECODER.VOCAB_SIZE,
hidden_size=cfg.TEXT_DECODER.HIDDEN_SIZE,
num_layers=cfg.TEXT_DECODER.NUM_LAYERS,
attention_heads=cfg.TEXT_DECODER.ATTENTION_HEADS,
feedforward_size=cfg.TEXT_DECODER.FEEDFORWARD_SIZE,
mask_future_positions=True,
padding_idx=0,
decoder_type='bert_en',
use_act_checkpoint=cfg.USE_ACT_CHECKPOINT,
)
ret.update({
'text_decoder_transformer': text_decoder_transformer,
'train_task': cfg.MODEL.TRAIN_TASK,
'test_task': cfg.MODEL.TEST_TASK,
'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE,
'mask_weight': cfg.MODEL.ROI_HEADS.MASK_WEIGHT,
'soft_nms_enabled': cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED,
'beam_size': cfg.MODEL.BEAM_SIZE,
})
return ret
@classmethod
def _init_box_head(self, cfg, input_shape):
ret = super()._init_box_head(cfg, input_shape)
del ret['box_predictors']
cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS
box_predictors = []
for box_head, bbox_reg_weights in zip(ret['box_heads'], \
cascade_bbox_reg_weights):
box_predictors.append(
GRiTFastRCNNOutputLayers(
cfg, box_head.output_shape,
box2box_transform=Box2BoxTransform(weights=bbox_reg_weights)
))
ret['box_predictors'] = box_predictors
in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features)
sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
object_feat_pooler = ROIPooler(
output_size=cfg.MODEL.ROI_HEADS.OBJECT_FEAT_POOLER_RES,
scales=pooler_scales,
sampling_ratio=sampling_ratio,
pooler_type=pooler_type,
)
ret['object_feat_pooler'] = object_feat_pooler
return ret
def check_if_all_background(self, proposals, targets, stage):
all_background = True
for proposals_per_image in proposals:
if not (proposals_per_image.gt_classes == self.num_classes).all():
all_background = False
if all_background:
logger.info('all proposals are background at stage {}'.format(stage))
proposals[0].proposal_boxes.tensor[0, :] = targets[0].gt_boxes.tensor[0, :]
proposals[0].gt_boxes.tensor[0, :] = targets[0].gt_boxes.tensor[0, :]
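# math.log((1 - 1e-10) / 1e-10) is the inverse sigmoid of (1 - 1e-10): the injected
# ground-truth box is given a near-maximal objectness logit.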
proposals[0].objectness_logits[0] = math.log((1.0 - 1e-10) / (1 - (1.0 - 1e-10)))
proposals[0].gt_classes[0] = targets[0].gt_classes[0]
proposals[0].gt_object_descriptions.data[0] = targets[0].gt_object_descriptions.data[0]
if 'foreground' in proposals[0].get_fields().keys():
proposals[0].foreground[0] = 1
return proposals
def _forward_box(self, features, proposals, targets=None, task="ObjectDet"):
if self.training:
proposals = self.check_if_all_background(proposals, targets, 0)
if (not self.training) and self.mult_proposal_score:
if len(proposals) > 0 and proposals[0].has('scores'):
proposal_scores = [p.get('scores') for p in proposals]
else:
proposal_scores = [p.get('objectness_logits') for p in proposals]
features = [features[f] for f in self.box_in_features]
head_outputs = []
prev_pred_boxes = None
image_sizes = [x.image_size for x in proposals]
for k in range(self.num_cascade_stages):
if k > 0:
proposals = self._create_proposals_from_boxes(
prev_pred_boxes, image_sizes,
logits=[p.objectness_logits for p in proposals])
if self.training:
proposals = self._match_and_label_boxes_GRiT(
proposals, k, targets)
proposals = self.check_if_all_background(proposals, targets, k)
predictions = self._run_stage(features, proposals, k)
prev_pred_boxes = self.box_predictor[k].predict_boxes(
(predictions[0], predictions[1]), proposals)
head_outputs.append((self.box_predictor[k], predictions, proposals))
if self.training:
object_features = self.object_feat_pooler(features, [x.proposal_boxes for x in proposals])
object_features = _ScaleGradient.apply(object_features, 1.0 / self.num_cascade_stages)
foreground = torch.cat([x.foreground for x in proposals])
object_features = object_features[foreground > 0]
object_descriptions = []
for x in proposals:
object_descriptions += x.gt_object_descriptions[x.foreground > 0].data
object_descriptions = ObjDescription(object_descriptions)
object_descriptions = object_descriptions.data
if len(object_descriptions) > 0:
begin_token = self.task_begin_tokens[task]
text_decoder_inputs = self.get_target_text_tokens(object_descriptions, object_features, begin_token)
object_features = object_features.view(
object_features.shape[0], object_features.shape[1], -1).permute(0, 2, 1).contiguous()
text_decoder_inputs.update({'object_features': object_features})
text_decoder_loss = self.text_decoder(text_decoder_inputs)
else:
text_decoder_loss = head_outputs[0][1][0].new_zeros([1])[0]
losses = {}
storage = get_event_storage()
# RoI Head losses (For the proposal generator loss, please find it in grit.py)
for stage, (predictor, predictions, proposals) in enumerate(head_outputs):
with storage.name_scope("stage{}".format(stage)):
stage_losses = predictor.losses(
(predictions[0], predictions[1]), proposals)
losses.update({k + "_stage{}".format(stage): v for k, v in stage_losses.items()})
# Text Decoder loss
losses.update({'text_decoder_loss': text_decoder_loss})
return losses
else:
scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs]
logits_per_stage = [(h[1][0],) for h in head_outputs]
scores = [
sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages)
for scores_per_image in zip(*scores_per_stage)
]
logits = [
sum(list(logits_per_image)) * (1.0 / self.num_cascade_stages)
for logits_per_image in zip(*logits_per_stage)
]
if self.mult_proposal_score:
scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores)]
predictor, predictions, proposals = head_outputs[-1]
boxes = predictor.predict_boxes(
(predictions[0], predictions[1]), proposals)
assert len(boxes) == 1
pred_instances, _ = self.fast_rcnn_inference_GRiT(
boxes,
scores,
logits,
image_sizes,
predictor.test_score_thresh,
predictor.test_nms_thresh,
predictor.test_topk_per_image,
self.soft_nms_enabled,
)
assert len(pred_instances) == 1, "Only support one image"
for i, pred_instance in enumerate(pred_instances):
if len(pred_instance.pred_boxes) > 0:
object_features = self.object_feat_pooler(features, [pred_instance.pred_boxes])
object_features = object_features.view(
object_features.shape[0], object_features.shape[1], -1).permute(0, 2, 1).contiguous()
text_decoder_output = self.text_decoder({'object_features': object_features})
if self.beam_size > 1 and self.test_task == "ObjectDet":
pred_boxes = []
pred_scores = []
pred_classes = []
pred_object_descriptions = []
for beam_id in range(self.beam_size):
pred_boxes.append(pred_instance.pred_boxes.tensor)
# object score = sqrt(objectness score x description score)
pred_scores.append((pred_instance.scores *
torch.exp(text_decoder_output['logprobs'])[:, beam_id]) ** 0.5)
pred_classes.append(pred_instance.pred_classes)
for prediction in text_decoder_output['predictions'][:, beam_id, :]:
# convert text tokens to words
description = self.tokenizer.decode(prediction.tolist()[1:], skip_special_tokens=True)
pred_object_descriptions.append(description)
merged_instances = Instances(image_sizes[0])
if torch.cat(pred_scores, dim=0).shape[0] <= predictor.test_topk_per_image:
merged_instances.scores = torch.cat(pred_scores, dim=0)
merged_instances.pred_boxes = Boxes(torch.cat(pred_boxes, dim=0))
merged_instances.pred_classes = torch.cat(pred_classes, dim=0)
merged_instances.pred_object_descriptions = ObjDescription(pred_object_descriptions)
else:
pred_scores, top_idx = torch.topk(
torch.cat(pred_scores, dim=0), predictor.test_topk_per_image)
merged_instances.scores = pred_scores
merged_instances.pred_boxes = Boxes(torch.cat(pred_boxes, dim=0)[top_idx, :])
merged_instances.pred_classes = torch.cat(pred_classes, dim=0)[top_idx]
merged_instances.pred_object_descriptions = \
ObjDescription(ObjDescription(pred_object_descriptions)[top_idx].data)
pred_instances[i] = merged_instances
else:
# object score = sqrt(objectness score x description score)
pred_instance.scores = (pred_instance.scores *
torch.exp(text_decoder_output['logprobs'])) ** 0.5
pred_object_descriptions = []
for prediction in text_decoder_output['predictions']:
# convert text tokens to words
description = self.tokenizer.decode(prediction.tolist()[1:], skip_special_tokens=True)
pred_object_descriptions.append(description)
pred_instance.pred_object_descriptions = ObjDescription(pred_object_descriptions)
else:
pred_instance.pred_object_descriptions = ObjDescription([])
return pred_instances
def forward(self, features, proposals, targets=None, targets_task="ObjectDet"):
if self.training:
proposals = self.label_and_sample_proposals(
proposals, targets)
losses = self._forward_box(features, proposals, targets, task=targets_task)
if targets[0].has('gt_masks'):
mask_losses = self._forward_mask(features, proposals)
losses.update({k: v * self.mask_weight \
for k, v in mask_losses.items()})
else:
losses.update(self._get_empty_mask_loss(device=proposals[0].objectness_logits.device))
return proposals, losses
else:
pred_instances = self._forward_box(features, proposals, task=self.test_task)
pred_instances = self.forward_with_given_boxes(features, pred_instances)
return pred_instances, {}
@torch.no_grad()
def _match_and_label_boxes_GRiT(self, proposals, stage, targets):
"""
Add "gt_object_description" and "foreground" to detectron2's _match_and_label_boxes
"""
num_fg_samples, num_bg_samples = [], []
for proposals_per_image, targets_per_image in zip(proposals, targets):
match_quality_matrix = pairwise_iou(
targets_per_image.gt_boxes, proposals_per_image.proposal_boxes
)
# proposal_labels are 0 or 1
matched_idxs, proposal_labels = self.proposal_matchers[stage](match_quality_matrix)
if len(targets_per_image) > 0:
gt_classes = targets_per_image.gt_classes[matched_idxs]
# Label unmatched proposals (0 label from matcher) as background (label=num_classes)
gt_classes[proposal_labels == 0] = self.num_classes
foreground = torch.ones_like(gt_classes)
foreground[proposal_labels == 0] = 0
gt_boxes = targets_per_image.gt_boxes[matched_idxs]
gt_object_descriptions = targets_per_image.gt_object_descriptions[matched_idxs]
else:
gt_classes = torch.zeros_like(matched_idxs) + self.num_classes
foreground = torch.zeros_like(gt_classes)
gt_boxes = Boxes(
targets_per_image.gt_boxes.tensor.new_zeros((len(proposals_per_image), 4))
)
gt_object_descriptions = ObjDescription(['None' for i in range(len(proposals_per_image))])
proposals_per_image.gt_classes = gt_classes
proposals_per_image.gt_boxes = gt_boxes
proposals_per_image.gt_object_descriptions = gt_object_descriptions
proposals_per_image.foreground = foreground
num_fg_samples.append((proposal_labels == 1).sum().item())
num_bg_samples.append(proposal_labels.numel() - num_fg_samples[-1])
# Log the number of fg/bg samples in each stage
storage = get_event_storage()
storage.put_scalar(
"stage{}/roi_head/num_fg_samples".format(stage),
sum(num_fg_samples) / len(num_fg_samples),
)
storage.put_scalar(
"stage{}/roi_head/num_bg_samples".format(stage),
sum(num_bg_samples) / len(num_bg_samples),
)
return proposals
def fast_rcnn_inference_GRiT(
self,
boxes: List[torch.Tensor],
scores: List[torch.Tensor],
logits: List[torch.Tensor],
image_shapes: List[Tuple[int, int]],
score_thresh: float,
nms_thresh: float,
topk_per_image: int,
soft_nms_enabled: bool,
):
result_per_image = [
self.fast_rcnn_inference_single_image_GRiT(
boxes_per_image, scores_per_image, logits_per_image, image_shape,
score_thresh, nms_thresh, topk_per_image, soft_nms_enabled
)
for scores_per_image, boxes_per_image, image_shape, logits_per_image \
in zip(scores, boxes, image_shapes, logits)
]
return [x[0] for x in result_per_image], [x[1] for x in result_per_image]
def fast_rcnn_inference_single_image_GRiT(
self,
boxes,
scores,
logits,
image_shape: Tuple[int, int],
score_thresh: float,
nms_thresh: float,
topk_per_image: int,
soft_nms_enabled,
):
"""
Add soft NMS to detectron2's fast_rcnn_inference_single_image
"""
valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
if not valid_mask.all():
boxes = boxes[valid_mask]
scores = scores[valid_mask]
logits = logits[valid_mask]
scores = scores[:, :-1]
logits = logits[:, :-1]
num_bbox_reg_classes = boxes.shape[1] // 4
# Convert to Boxes to use the `clip` function ...
boxes = Boxes(boxes.reshape(-1, 4))
boxes.clip(image_shape)
boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 4) # R x C x 4
# 1. Filter results based on detection scores. It can make NMS more efficient
# by filtering out low-confidence detections.
filter_mask = scores > score_thresh # R x K
# R' x 2. First column contains indices of the R predictions;
# Second column contains indices of classes.
filter_inds = filter_mask.nonzero()
if num_bbox_reg_classes == 1:
boxes = boxes[filter_inds[:, 0], 0]
else:
boxes = boxes[filter_mask]
scores = scores[filter_mask]
logits = logits[filter_mask]
# 2. Apply NMS for each class independently.
if not soft_nms_enabled:
keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
else:
keep, soft_nms_scores = batched_soft_nms(
boxes,
scores,
filter_inds[:, 1],
"linear",
0.5,
nms_thresh,
0.001,
)
scores[keep] = soft_nms_scores
if topk_per_image >= 0:
keep = keep[:topk_per_image]
boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]
logits = logits[keep]
result = Instances(image_shape)
result.pred_boxes = Boxes(boxes)
result.scores = scores
result.pred_classes = filter_inds[:, 1]
result.logits = logits
return result, filter_inds[:, 0]
def _get_empty_mask_loss(self, device):
if self.mask_on:
return {'loss_mask': torch.zeros(
(1, ), device=device, dtype=torch.float32)[0]}
else:
return {}
def _create_proposals_from_boxes(self, boxes, image_sizes, logits):
boxes = [Boxes(b.detach()) for b in boxes]
proposals = []
for boxes_per_image, image_size, logit in zip(
boxes, image_sizes, logits):
boxes_per_image.clip(image_size)
if self.training:
inds = boxes_per_image.nonempty()
boxes_per_image = boxes_per_image[inds]
logit = logit[inds]
prop = Instances(image_size)
prop.proposal_boxes = boxes_per_image
prop.objectness_logits = logit
proposals.append(prop)
return proposals
def _run_stage(self, features, proposals, stage):
pool_boxes = [x.proposal_boxes for x in proposals]
box_features = self.box_pooler(features, pool_boxes)
box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages)
box_features = self.box_head[stage](box_features)
return self.box_predictor[stage](box_features)
import torch
from detectron2.structures import Boxes, RotatedBoxes, pairwise_iou, pairwise_iou_rotated
def soft_nms(boxes, scores, method, gaussian_sigma, linear_threshold, prune_threshold):
"""
Performs soft non-maximum suppression algorithm on axis aligned boxes
Args:
boxes (Tensor[N, 4]):
boxes where NMS will be performed. They
are expected to be in (x1, y1, x2, y2) format
scores (Tensor[N]):
scores for each one of the boxes
method (str):
one of ['gaussian', 'linear', 'hard']
see paper for details. users encouraged not to use "hard", as this is the
same nms available elsewhere in detectron2
gaussian_sigma (float):
parameter for Gaussian penalty function
linear_threshold (float):
iou threshold for applying linear decay. Nt from the paper
re-used as threshold for standard "hard" nms
prune_threshold (float):
boxes with scores below this threshold are pruned at each iteration.
Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
Returns:
tuple(Tensor, Tensor):
[0]: int64 tensor with the indices of the elements that have been kept
by Soft NMS, sorted in decreasing order of scores
[1]: float tensor with the re-scored scores of the elements that were kept
"""
return _soft_nms(
Boxes,
pairwise_iou,
boxes,
scores,
method,
gaussian_sigma,
linear_threshold,
prune_threshold,
)
def batched_soft_nms(
boxes, scores, idxs, method, gaussian_sigma, linear_threshold, prune_threshold
):
"""
Performs soft non-maximum suppression in a batched fashion.
Each index value correspond to a category, and NMS
will not be applied between elements of different categories.
Args:
boxes (Tensor[N, 4]):
boxes where NMS will be performed. They
are expected to be in (x1, y1, x2, y2) format
scores (Tensor[N]):
scores for each one of the boxes
idxs (Tensor[N]):
indices of the categories for each one of the boxes.
method (str):
one of ['gaussian', 'linear', 'hard']
see paper for details. users encouraged not to use "hard", as this is the
same nms available elsewhere in detectron2
gaussian_sigma (float):
parameter for Gaussian penalty function
linear_threshold (float):
iou threshold for applying linear decay. Nt from the paper
re-used as threshold for standard "hard" nms
prune_threshold (float):
boxes with scores below this threshold are pruned at each iteration.
Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
Returns:
tuple(Tensor, Tensor):
[0]: int64 tensor with the indices of the elements that have been kept
by Soft NMS, sorted in decreasing order of scores
[1]: float tensor with the re-scored scores of the elements that were kept
"""
if boxes.numel() == 0:
return (
torch.empty((0,), dtype=torch.int64, device=boxes.device),
torch.empty((0,), dtype=torch.float32, device=scores.device),
)
# strategy: in order to perform NMS independently per class.
# we add an offset to all the boxes. The offset is dependent
# only on the class idx, and is large enough so that boxes
# from different classes do not overlap
max_coordinate = boxes.max()
offsets = idxs.to(boxes) * (max_coordinate + 1)
boxes_for_nms = boxes + offsets[:, None]
return soft_nms(
boxes_for_nms, scores, method, gaussian_sigma, linear_threshold, prune_threshold
)
def _soft_nms(
box_class,
pairwise_iou_func,
boxes,
scores,
method,
gaussian_sigma,
linear_threshold,
prune_threshold,
):
"""
Soft non-max suppression algorithm.
Implementation of [Soft-NMS -- Improving Object Detection With One Line of Code]
(https://arxiv.org/abs/1704.04503)
Args:
box_class (cls): one of Box, RotatedBoxes
pairwise_iou_func (func): one of pairwise_iou, pairwise_iou_rotated
boxes (Tensor[N, ?]):
boxes where NMS will be performed
if Boxes, in (x1, y1, x2, y2) format
if RotatedBoxes, in (x_ctr, y_ctr, width, height, angle_degrees) format
scores (Tensor[N]):
scores for each one of the boxes
method (str):
one of ['gaussian', 'linear', 'hard']
see paper for details. users encouraged not to use "hard", as this is the
same nms available elsewhere in detectron2
gaussian_sigma (float):
parameter for Gaussian penalty function
linear_threshold (float):
iou threshold for applying linear decay. Nt from the paper
re-used as threshold for standard "hard" nms
prune_threshold (float):
boxes with scores below this threshold are pruned at each iteration.
Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
Returns:
tuple(Tensor, Tensor):
[0]: int64 tensor with the indices of the elements that have been kept
by Soft NMS, sorted in decreasing order of scores
[1]: float tensor with the re-scored scores of the elements that were kept
"""
boxes = boxes.clone()
scores = scores.clone()
idxs = torch.arange(scores.size()[0])
idxs_out = []
scores_out = []
while scores.numel() > 0:
top_idx = torch.argmax(scores)
idxs_out.append(idxs[top_idx].item())
scores_out.append(scores[top_idx].item())
top_box = boxes[top_idx]
ious = pairwise_iou_func(box_class(top_box.unsqueeze(0)), box_class(boxes))[0]
if method == "linear":
decay = torch.ones_like(ious)
decay_mask = ious > linear_threshold
decay[decay_mask] = 1 - ious[decay_mask]
elif method == "gaussian":
decay = torch.exp(-torch.pow(ious, 2) / gaussian_sigma)
elif method == "hard": # standard NMS
decay = (ious < linear_threshold).float()
else:
raise NotImplementedError("{} soft nms method not implemented.".format(method))
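# Rescore every remaining box by its overlap-dependent decay, then prune boxes whose
# score falls below prune_threshold; the selected box itself is dropped from the pool
# so it cannot be picked again.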
scores *= decay
keep = scores > prune_threshold
keep[top_idx] = False
boxes = boxes[keep]
scores = scores[keep]
idxs = idxs[keep]
return torch.tensor(idxs_out).to(boxes.device), torch.tensor(scores_out).to(scores.device)
from __future__ import absolute_import, division, print_function, unicode_literals
import sys
import json
import logging
import os
import shutil
import tempfile
import fnmatch
from functools import wraps
from hashlib import sha256
from io import open
import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm
try:
from torch.hub import _get_torch_home
torch_cache_home = _get_torch_home()
except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv('TORCH_HOME', os.path.join(
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers')
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
default_cache_path)
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
"""
url_bytes = url.encode('utf-8')
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode('utf-8')
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()
return filename
def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + '.json'
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata['url']
etag = metadata['etag']
return url, etag
def cached_path(url_or_filename, cache_dir=None):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename)
if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif parsed.scheme == '':
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url):
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path
def s3_request(func):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise EnvironmentError("file {} not found".format(url))
else:
raise
return wrapper
@s3_request
def s3_etag(url):
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag
@s3_request
def s3_get(url, temp_file):
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file):
req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url, cache_dir=None):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
cache_dir = str(cache_dir)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url)
else:
try:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
etag = None
else:
etag = response.headers.get("ETag")
except EnvironmentError:
etag = None
if sys.version_info[0] == 2 and etag is not None:
etag = etag.decode('utf-8')
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file)
else:
http_get(url, temp_file)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, 'wb') as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file:
output_string = json.dumps(meta)
meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)
return cache_path
import torch
class LoadTextTokens(object):
def __init__(self, tokenizer, max_text_len=40, padding='do_not_pad'):
self.tokenizer = tokenizer
self.max_text_len = max_text_len
self.padding = padding
def descriptions_to_text_tokens(self, target, begin_token):
target_encoding = self.tokenizer(
target, padding=self.padding,
add_special_tokens=False,
truncation=True, max_length=self.max_text_len)
need_predict = [1] * len(target_encoding['input_ids'])
payload = target_encoding['input_ids']
if len(payload) > self.max_text_len - 2:
payload = payload[-(self.max_text_len - 2):]
need_predict = need_predict[-(self.max_text_len - 2):]
input_ids = [begin_token] + payload + [self.tokenizer.sep_token_id]
need_predict = [0] + need_predict + [1]
data = {
'text_tokens': torch.tensor(input_ids),
'text_lengths': len(input_ids),
'need_predict': torch.tensor(need_predict),
}
return data
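# The dict returned above has the following layout (illustrative, for a description
# that tokenizes to ids [t1, t2]):
#   text_tokens  = [begin_token, t1, t2, sep_token_id]
#   need_predict = [0, 1, 1, 1]   # supervise every position except the begin token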
def __call__(self, object_descriptions, box_features, begin_token):
text_tokens = []
text_lengths = []
need_predict = []
for description in object_descriptions:
tokens = self.descriptions_to_text_tokens(description, begin_token)
text_tokens.append(tokens['text_tokens'])
text_lengths.append(tokens['text_lengths'])
need_predict.append(tokens['need_predict'])
text_tokens = torch.cat(self.collate(text_tokens), dim=0).to(box_features.device)
text_lengths = torch.tensor(text_lengths).to(box_features.device)
need_predict = torch.cat(self.collate(need_predict), dim=0).to(box_features.device)
assert text_tokens.dim() == 2 and need_predict.dim() == 2
data = {'text_tokens': text_tokens,
'text_lengths': text_lengths,
'need_predict': need_predict}
return data
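# collate zero-pads a list of variable-length 1-D/2-D/3-D tensors to a common shape
# (when their shapes differ) and prepends a batch dimension of size 1 to each, so the
# caller can concatenate them with torch.cat along dim 0.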
def collate(self, batch):
if all(isinstance(b, torch.Tensor) for b in batch) and len(batch) > 0:
if not all(b.shape == batch[0].shape for b in batch[1:]):
assert all(len(b.shape) == len(batch[0].shape) for b in batch[1:])
shape = torch.tensor([b.shape for b in batch])
max_shape = tuple(shape.max(dim=0)[0].tolist())
batch2 = []
for b in batch:
if any(c < m for c, m in zip(b.shape, max_shape)):
b2 = torch.zeros(max_shape, dtype=b.dtype, device=b.device)
if b.dim() == 1:
b2[:b.shape[0]] = b
elif b.dim() == 2:
b2[:b.shape[0], :b.shape[1]] = b
elif b.dim() == 3:
b2[:b.shape[0], :b.shape[1], :b.shape[2]] = b
else:
raise NotImplementedError
b = b2
batch2.append(b[None, ...])
else:
batch2 = []
for b in batch:
batch2.append(b[None, ...])
return batch2
else:
raise NotImplementedError
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import os
import json
import logging
import math
import sys
from io import open
import torch
from torch import nn
import torch.utils.checkpoint as checkpoint
from .file_utils import cached_path
logger = logging.getLogger()
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
}
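# Scaled dot-product attention weights: softmax(Q K^T / gamma + mask); the caller
# passes gamma = sqrt(attention_head_size).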
def qk2attn(query, key, attention_mask, gamma):
query = query / gamma
attention_scores = torch.matmul(query, key.transpose(-1, -2))
if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
return attention_scores.softmax(dim=-1)
class QK2Attention(nn.Module):
def forward(self, query, key, attention_mask, gamma):
return qk2attn(query, key, attention_mask, gamma)
LayerNormClass = torch.nn.LayerNorm
class BertSelfAttention(nn.Module):
def __init__(self, config):
super(BertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
self.qk2attn = QK2Attention()
def transpose_for_scores(self, x):
if torch._C._get_tracing_state():
# exporter is not smart enough to detect dynamic size for some paths
x = x.view(x.shape[0], -1, self.num_attention_heads, self.attention_head_size)
else:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask, head_mask=None,
history_state=None):
if history_state is not None:
x_states = torch.cat([history_state, hidden_states], dim=1)
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(x_states)
mixed_value_layer = self.value(x_states)
else:
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
attention_probs = self.qk2attn(query_layer, key_layer, attention_mask, math.sqrt(self.attention_head_size))
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
return outputs
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm
if not self.pre_norm:
self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
if not self.pre_norm:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
else:
hidden_states = hidden_states + input_tensor
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm
if self.pre_norm:
self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.layer_norm_eps)
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask, head_mask=None,
history_state=None):
if self.pre_norm:
self_outputs = self.self(self.LayerNorm(input_tensor), attention_mask, head_mask,
self.LayerNorm(history_state) if history_state is not None else history_state)
else:
self_outputs = self.self(input_tensor, attention_mask, head_mask,
history_state)
attention_output = self.output(self_outputs[0], input_tensor)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
assert config.hidden_act == 'gelu', 'Please implement other activation functions'
self.intermediate_act_fn = _gelu_python
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if not self.pre_norm:
self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
if not self.pre_norm:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
else:
hidden_states = hidden_states + input_tensor
return hidden_states
class Mlp(nn.Module):
def __init__(self, config):
super().__init__()
self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm
self.intermediate = BertIntermediate(config)
if self.pre_norm:
self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.layer_norm_eps)
self.output = BertOutput(config)
def forward(self, attention_output):
if not self.pre_norm:
intermediate_output = self.intermediate(attention_output)
else:
intermediate_output = self.intermediate(self.LayerNorm(attention_output))
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertLayer(nn.Module):
def __init__(self, config, use_act_checkpoint=True):
super(BertLayer, self).__init__()
self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm
self.use_mlp_wrapper = hasattr(config, 'use_mlp_wrapper') and config.use_mlp_wrapper
self.attention = BertAttention(config)
self.use_act_checkpoint = use_act_checkpoint
if self.use_mlp_wrapper:
self.mlp = Mlp(config)
else:
self.intermediate = BertIntermediate(config)
if self.pre_norm:
self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.layer_norm_eps)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask, head_mask=None,
history_state=None):
if self.use_act_checkpoint:
attention_outputs = checkpoint.checkpoint(self.attention, hidden_states,
attention_mask, head_mask, history_state)
else:
attention_outputs = self.attention(hidden_states, attention_mask,
head_mask, history_state)
attention_output = attention_outputs[0]
if self.use_mlp_wrapper:
layer_output = self.mlp(attention_output)
else:
if not self.pre_norm:
intermediate_output = self.intermediate(attention_output)
else:
intermediate_output = self.intermediate(self.LayerNorm(attention_output))
layer_output = self.output(intermediate_output, attention_output)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs
class BertEncoder(nn.Module):
def __init__(self, config, use_act_checkpoint=True):
super(BertEncoder, self).__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config, use_act_checkpoint=use_act_checkpoint) for _ in range(config.num_hidden_layers)])
self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm
if self.pre_norm:
self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask, head_mask=None,
encoder_history_states=None):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
history_state = None if encoder_history_states is None else encoder_history_states[i]
layer_outputs = layer_module(
hidden_states, attention_mask,
(None if head_mask is None else head_mask[i]),
history_state,
)
hidden_states = layer_outputs[0]
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if self.pre_norm:
hidden_states = self.LayerNorm(hidden_states)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs
CONFIG_NAME = "config.json"
class PretrainedConfig(object):
""" Base class for all configuration classes.
Handles a few common parameters and methods for loading/downloading/saving configurations.
"""
pretrained_config_archive_map = {}
def __init__(self, **kwargs):
self.finetuning_task = kwargs.pop('finetuning_task', None)
self.num_labels = kwargs.pop('num_labels', 2)
self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.torchscript = kwargs.pop('torchscript', False)
def save_pretrained(self, save_directory):
""" Save a configuration object to a directory, so that it
can be re-loaded using the `from_pretrained(save_directory)` class method.
"""
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
# If we save using the predefined names, we can load using `from_pretrained`
output_config_file = os.path.join(save_directory, CONFIG_NAME)
self.to_json_file(output_config_file)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r""" Instantiate a PretrainedConfig from a pre-trained model configuration.
Params:
**pretrained_model_name_or_path**: either:
- a string with the `shortcut name` of a pre-trained model configuration to load from cache
or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
- a path to a `directory` containing a configuration file saved
using the `save_pretrained(save_directory)` method.
- a path or url to a saved configuration `file`.
**cache_dir**: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
**return_unused_kwargs**: (`optional`) bool:
- If False, then this function returns just the final configuration object.
- If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs`
is a dictionary consisting of the key/value pairs whose keys are not configuration attributes:
ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
**kwargs**: (`optional`) dict:
Dictionary of key/value pairs with which to update the configuration object after loading.
- The values in kwargs of any keys which are configuration attributes will be used
to override the loaded values.
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
by the `return_unused_kwargs` keyword parameter.
Examples::
>>> config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
>>> config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
>>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
>>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
>>> assert config.output_attention == True
>>> config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
>>> foo=False, return_unused_kwargs=True)
>>> assert config.output_attention == True
>>> assert unused_kwargs == {'foo': False}
"""
cache_dir = kwargs.pop('cache_dir', None)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
else:
config_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_config_archive_map.keys()),
config_file))
return None
if resolved_config_file == config_file:
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = cls.from_json_file(resolved_config_file)
# Update config with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
# add img_layer_norm_eps, use_img_layernorm
if "img_layer_norm_eps" in kwargs:
setattr(config, "img_layer_norm_eps", kwargs["img_layer_norm_eps"])
to_remove.append("img_layer_norm_eps")
if "use_img_layernorm" in kwargs:
setattr(config, "use_img_layernorm", kwargs["use_img_layernorm"])
to_remove.append("use_img_layernorm")
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config %s", config)
if return_unused_kwargs:
return config, kwargs
else:
return config
@classmethod
def from_dict(cls, json_object):
"""Constructs a `Config` from a Python dictionary of parameters."""
config = cls(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __eq__(self, other):
return self.__dict__ == other.__dict__
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
class BertConfig(PretrainedConfig):
r"""
:class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a
`BertModel`.
Arguments:
vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
"""
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
**kwargs):
super(BertConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
def _gelu_python(x):
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
from torch import nn
import torch
import functools
from torch.nn import functional as F
import warnings
class TextualHead(nn.Module):
def __init__(self,
visual_feature_size: int, vocab_size: int, hidden_size: int):
super().__init__()
self.visual_feature_size = visual_feature_size
self.vocab_size = vocab_size
self.hidden_size = hidden_size
@property
def textual_feature_size(self):
return self.hidden_size
class WordAndPositionalEmbedding(nn.Module):
def __init__(
self,
vocab_size: int,
hidden_size: int,
dropout: float = 0.0,
max_caption_length: int = 30,
padding_idx: int = 0,
):
super().__init__()
self.vocab_size = vocab_size
self.padding_idx = padding_idx
#self.words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)
self.words = nn.Embedding(vocab_size, hidden_size)
# We provide no "padding index" for positional embeddings. We zero out
# the positional embeddings of padded positions as a post-processing.
self.positions = nn.Embedding(max_caption_length, hidden_size)
self.layer_norm = nn.LayerNorm(
hidden_size, eps=1e-8, elementwise_affine=True
)
self.dropout = nn.Dropout(p=dropout)
def forward(self, tokens: torch.Tensor):
position_indices = self._create_position_indices(tokens)
# shape: (batch_size, max_caption_length, hidden_size)
word_embeddings = self.words(tokens)
position_embeddings = self.positions(position_indices)
# shape: (batch_size, max_caption_length, hidden_size)
embeddings = self.layer_norm(word_embeddings + position_embeddings)
embeddings = self.dropout(embeddings)
return embeddings
@functools.lru_cache(maxsize=128)
def _create_position_indices(self, tokens: torch.Tensor):
# Create position indices of the same size as token indices.
batch_size, max_caption_length = tokens.size()
positions = torch.arange(
max_caption_length, dtype=tokens.dtype, device=tokens.device
)
# shape: (batch_size, max_caption_length)
positions = positions.unsqueeze(0).expand(batch_size, max_caption_length)
return positions
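# Illustrative usage sketch (not part of the original file): embedding a batch of token
# ids with WordAndPositionalEmbedding. The vocabulary size, caption length, and hidden
# size below are arbitrary example values.
def _demo_word_positional_embedding():
    embedder = WordAndPositionalEmbedding(vocab_size=1000, hidden_size=64, max_caption_length=20)
    tokens = torch.randint(0, 1000, (2, 20))   # shape: (batch_size, max_caption_length)
    embeddings = embedder(tokens)              # shape: (2, 20, 64)
    return embeddings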
class BertEncoderAsDecoder(nn.Module):
def __init__(self, encoder):
super().__init__()
self.encoder = encoder
def forward(self, tgt, memory,
tgt_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
tgt_bi_valid_mask=None,
encoder_history_states=None,
):
assert tgt_key_padding_mask is None, 'not supported'
assert tgt_mask.dim() == 2
assert tgt_mask.shape[0] == tgt_mask.shape[1]
# tgt_mask should always be 0/negative infinity
tgt = tgt.transpose(0, 1)
memory = memory.transpose(0, 1)
hidden_states = torch.cat((memory, tgt), dim=1)
num_tgt = tgt.shape[1]
num_memory = memory.shape[1]
device = tgt.device
dtype = tgt.dtype
top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
top_right = torch.full((num_memory, num_tgt), float('-inf'), device=tgt.device, dtype=dtype,)
bottom_left = torch.zeros((num_tgt, num_memory), dtype=dtype, device=tgt_mask.device,)
left = torch.cat((top_left, bottom_left), dim=0)
right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)
full_attention_mask = torch.cat((left, right), dim=1)[None, :]
if memory_key_padding_mask is None:
memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, dtype=torch.bool, device=device)
# False means the position is valid, i.e. it is not padding.
assert memory_key_padding_mask.dtype == torch.bool
zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
zero_negative_infinity[memory_key_padding_mask] = float('-inf')
full_attention_mask = full_attention_mask.expand((memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + num_tgt))
full_attention_mask = full_attention_mask.clone()
origin_left = full_attention_mask[:, :, :num_memory]
update = zero_negative_infinity[:, None, :]
full_attention_mask[:, :, :num_memory] = origin_left + update
if tgt_bi_valid_mask is not None:
# verify the correctness
bs = full_attention_mask.shape[0]
# during inference, tgt_bi_valid_mask's length is not changed, but
# num_tgt can be increased
max_valid_target = tgt_bi_valid_mask.shape[1]
mask = tgt_bi_valid_mask[:, None, :].expand((bs, num_memory+num_tgt, max_valid_target))
full_attention_mask[:, :, num_memory:(num_memory+max_valid_target)][mask] = 0
# add axis for multi-head
full_attention_mask = full_attention_mask[:, None, :, :]
if encoder_history_states is None:
result = self.encoder(
hidden_states=hidden_states,
attention_mask=full_attention_mask,
encoder_history_states=encoder_history_states,
)
result = list(result)
result[0] = result[0][:, num_memory:].transpose(0, 1)
if self.encoder.output_hidden_states:
return result[0], result[1]
else:
# make it back-compatible
return result[0]
else:
encoder_out = self.encoder(
hidden_states=hidden_states[:, -1:],
attention_mask=full_attention_mask[:, :, -1:],
encoder_history_states=encoder_history_states,
)
result = encoder_out[0].transpose(0, 1)
if self.encoder.output_hidden_states:
return result, encoder_out[1]
else:
return result
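# Illustrative sketch (not part of the original file): how BertEncoderAsDecoder.forward
# assembles the joint attention mask over memory and target tokens, shown for a toy
# case; memory tokens attend only to memory, while target tokens attend to memory plus
# a causal prefix of targets.
def _demo_decoder_attention_mask(num_memory=2, num_tgt=3):
    tgt_mask = torch.triu(torch.full((num_tgt, num_tgt), float('-inf')), diagonal=1)
    top_left = torch.zeros((num_memory, num_memory))
    top_right = torch.full((num_memory, num_tgt), float('-inf'))
    bottom_left = torch.zeros((num_tgt, num_memory))
    left = torch.cat((top_left, bottom_left), dim=0)
    right = torch.cat((top_right, tgt_mask), dim=0)
    # shape: (num_memory + num_tgt, num_memory + num_tgt); 0 = attend, -inf = blocked
    return torch.cat((left, right), dim=1)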
def create_transformer(decoder_type, norm_type,
textual_feature_size,
attention_heads,
feedforward_size,
dropout,
num_layers,
output_hidden_states=False,
use_mlp_wrapper=None,
use_act_checkpoint=True,
):
assert norm_type in ['post', 'pre']
if decoder_type is None:
LayerClass = (
nn.TransformerDecoderLayer
if norm_type == "post"
else PreNormTransformerDecoderLayer
)
_layer = LayerClass(
textual_feature_size,
attention_heads,
dim_feedforward=feedforward_size,
dropout=dropout,
activation="gelu",
)
return nn.TransformerDecoder(_layer, num_layers)
elif decoder_type == 'bert_en':
from .modeling_bert import BertConfig, BertEncoder
config = BertConfig(
vocab_size_or_config_json_file=30522,
hidden_size=textual_feature_size,
num_hidden_layers=num_layers,
num_attention_heads=attention_heads,
intermediate_size=feedforward_size,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
layer_norm_eps=1e-12,
)
config.pre_norm = (norm_type == 'pre')
config.use_mlp_wrapper = use_mlp_wrapper
config.output_hidden_states = output_hidden_states
encoder = BertEncoder(config, use_act_checkpoint=use_act_checkpoint)
return BertEncoderAsDecoder(encoder)
class PreNormTransformerDecoderLayer(nn.TransformerDecoderLayer):
def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
tgt_key_padding_mask=None, memory_key_padding_mask=None):
# fmt: off
# We use the members (modules) from super-class, just the order of
# operations is changed here. First layernorm, then attention.
tgt2 = self.norm1(tgt)
tgt2, _ = self.self_attn(
tgt2, tgt2, tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask
)
tgt = tgt + self.dropout1(tgt2)
# Layernorm first, then decoder attention.
tgt2 = self.norm2(tgt)
tgt2, _ = self.multihead_attn(
tgt2, memory, memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask
)
tgt = tgt + self.dropout2(tgt2)
# Layernorm first, then transformation through feedforward network.
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
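# Illustrative usage sketch (not part of the original file): building the default
# PyTorch decoder (decoder_type=None) with create_transformer; the sizes below are
# arbitrary example values.
def _demo_create_transformer():
    decoder = create_transformer(
        decoder_type=None,
        norm_type='pre',
        textual_feature_size=64,
        attention_heads=4,
        feedforward_size=256,
        dropout=0.1,
        num_layers=2,
    )
    # decoder is an nn.TransformerDecoder stacking PreNormTransformerDecoderLayer
    return decoder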
class TransformerDecoderTextualHead(TextualHead):
def __init__(
self,
object_feature_size: int,
vocab_size: int,
hidden_size: int,
num_layers: int,
attention_heads: int,
feedforward_size: int,
dropout: float = 0.1,
norm_type: str = "post",
mask_future_positions: bool = True,
max_caption_length: int = 1024,
padding_idx: int = 0,
decoder_type=None,
not_tie_weight=None,
output_hidden_states=None,
use_mlp_wrapper=None,
use_act_checkpoint=True,
):
super().__init__(object_feature_size, vocab_size, hidden_size)
self.num_layers = num_layers
self.attention_heads = attention_heads
self.feedforward_size = feedforward_size
self.dropout = dropout
assert mask_future_positions
self.padding_idx = padding_idx
self.object_feature_projection = nn.Sequential(
nn.Linear(object_feature_size, self.textual_feature_size),
nn.LayerNorm(self.textual_feature_size))
self.embedding = WordAndPositionalEmbedding(
self.vocab_size,
self.textual_feature_size,
dropout=dropout,
max_caption_length=max_caption_length,
padding_idx=padding_idx,
)
self.transformer = create_transformer(
decoder_type=decoder_type,
norm_type=norm_type,
textual_feature_size=self.textual_feature_size,
attention_heads=self.attention_heads,
feedforward_size=self.feedforward_size,
dropout=dropout,
num_layers=self.num_layers,
output_hidden_states=output_hidden_states,
use_mlp_wrapper=use_mlp_wrapper,
use_act_checkpoint=use_act_checkpoint,
)
self.apply(self._init_weights)
# Create an output linear layer and tie the input and output word
# embeddings to reduce parameters.
self.output = nn.Linear(self.textual_feature_size, vocab_size)
if not not_tie_weight:
self.output.weight = self.embedding.words.weight
@staticmethod
def _init_weights(module):
"""Initialize weights like BERT - N(0.0, 0.02), bias = 0."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.MultiheadAttention):
module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def forward(
self,
hidden_states,
text_tokens,
):
projected_object_features = self.object_feature_projection(hidden_states) if hidden_states is not None else None
batch_size, max_text_length = text_tokens.size()
text_embeddings = self.embedding(text_tokens)
# An additive mask for masking the future (one direction).
uni_mask_zero_neg = self._generate_future_mask(
max_text_length, text_embeddings.dtype, text_embeddings.device
)
# We transpose the first two dimensions of tokens embeddings and visual
# features, as required by decoder.
text_embeddings = text_embeddings.transpose(0, 1)
projected_object_features = projected_object_features.transpose(0, 1)
# If the transformer here is the PyTorch nn.TransformerDecoder, the output is
# always a plain tensor, never a tuple.
trans_out = self.transformer(
text_embeddings,
projected_object_features,
tgt_mask=uni_mask_zero_neg,
)
if isinstance(trans_out, tuple):
textual_features = trans_out[0]
else:
assert isinstance(trans_out, torch.Tensor)
textual_features = trans_out
# Undo the transpose and bring batch to dim 0.
# shape: (batch_size, max_caption_length, hidden_size)
textual_features = textual_features.transpose(0, 1)
# shape: (batch_size, max_caption_length, vocab_size)
output_logits = self.output(textual_features)
if isinstance(trans_out, tuple):
return output_logits, trans_out[1]
else:
return output_logits
def _generate_future_mask(
self, size: int, dtype: torch.dtype, device: torch.device
):
# Default mask is for forward direction. Flip for backward direction.
mask = torch.triu(
torch.ones(size, size, device=device, dtype=dtype), diagonal=1
)
mask = mask.masked_fill(mask == 1, float("-inf"))
return mask
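# Illustrative sketch (not part of the original file): the additive causal mask produced
# by _generate_future_mask for 4 positions; entries above the diagonal are -inf so each
# position attends only to itself and earlier positions.
def _demo_future_mask(size=4):
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask.masked_fill(mask == 1, float("-inf"))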
class AutoRegressiveBeamSearch(object):
def __init__(
self,
end_token_id: int,
max_steps: int = 50,
beam_size: int = 5,
objectdet=True,
per_node_beam_size: int = 2,
):
self._eos_index = end_token_id
self.max_steps = max_steps
self.beam_size = beam_size
self.objectdet = objectdet
self.per_node_beam_size = per_node_beam_size or beam_size
def search(self, begin_tokens, step):
if self.beam_size > 1 and self.objectdet:
only_return_best = False
else:
only_return_best = True
batch_size = begin_tokens.size()[0]
predictions = begin_tokens.unsqueeze(1).expand((batch_size, self.beam_size, begin_tokens.shape[-1]))
# Calculate the first timestep. This is done outside the main loop
# because we are going from a single decoder input (the output from the
# encoder) to the top `beam_size` decoder outputs. On the other hand,
# within the main loop we are going from the `beam_size` elements of the
# beam to `beam_size`^2 candidates from which we will select the top
# `beam_size` elements for the next iteration.
# shape: (batch_size, num_classes)
start_class_logits = step(begin_tokens)
# Convert logits to logprobs.
# shape: (batch_size * beam_size, vocab_size)
start_class_logprobs = F.log_softmax(start_class_logits, dim=1)
num_classes = start_class_logprobs.size()[1]
# shape: (batch_size, beam_size), (batch_size, beam_size)
start_top_logprobs, start_predicted_classes = start_class_logprobs.topk(
self.beam_size
)
if (
self.beam_size == 1
and (start_predicted_classes == self._eos_index).all()
):
warnings.warn(
"Empty object description predicted. You may want to increase beam"
"size or ensure your step function is working properly.",
RuntimeWarning,
)
if only_return_best:
return start_predicted_classes, start_top_logprobs
else:
return start_predicted_classes.unsqueeze(-1), start_top_logprobs
# The log probs for the last time step.
# shape: (batch_size, beam_size)
last_logprobs = start_top_logprobs
# shape: (batch_size, beam_size, sequence_length)
predictions = torch.cat([predictions, start_predicted_classes.unsqueeze(-1)], dim=-1)
# Log probability tensor that mandates that the end token is selected.
# shape: (batch_size * beam_size, num_classes)
logprobs_after_end = start_class_logprobs.new_full(
(batch_size * self.beam_size, num_classes), float("-inf")
)
logprobs_after_end[:, self._eos_index] = 0.0
logits_after_end = start_class_logprobs.new_full(
(batch_size * self.beam_size, num_classes), float("-inf")
)
logits_after_end[:, self._eos_index] = 0
while predictions.shape[-1] < self.max_steps:
# shape: (batch_size * beam_size,)
last_predictions = predictions[:, :, -1].reshape(batch_size * self.beam_size)
# If every predicted token from the last step is `self._eos_index`,
# then we can stop early.
if (last_predictions == self._eos_index).all():
break
predictions_so_far = predictions.view(
batch_size * self.beam_size, -1
)
# shape: (batch_size * beam_size, num_classes)
class_logits = step(predictions_so_far)
# Set the logits of the last predicted tokens to a large negative value to
# avoid immediate repetition in the description.
class_logits = class_logits.scatter(1, predictions_so_far[:, -1].view((-1, 1)), -10000)
# shape: (batch_size * beam_size, num_classes)
last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
batch_size * self.beam_size, num_classes
)
# Here we are finding any beams where we predicted the end token in
# the previous timestep and replacing the distribution with a
# one-hot distribution, forcing the beam to predict the end token
# this timestep as well.
class_logits = torch.where(
last_predictions_expanded == self._eos_index,
logits_after_end,
class_logits,
)
# Convert logits to logprobs.
# shape: (batch_size * beam_size, vocab_size)
class_logprobs = F.log_softmax(class_logits, dim=1)
# shape (both): (batch_size * beam_size, per_node_beam_size)
top_logprobs, predicted_classes = class_logprobs.topk(
self.per_node_beam_size
)
# Here we expand the last log probs to `(batch_size * beam_size,
# per_node_beam_size)` so that we can add them to the current log
# probs for this timestep. This lets us maintain the log
# probability of each element on the beam.
# shape: (batch_size * beam_size, per_node_beam_size)
expanded_last_logprobs = (
last_logprobs.unsqueeze(2)
.expand(batch_size, self.beam_size, self.per_node_beam_size)
.reshape(batch_size * self.beam_size, self.per_node_beam_size)
)
# shape: (batch_size * beam_size, per_node_beam_size)
summed_top_logprobs = top_logprobs + expanded_last_logprobs
# shape: (batch_size, beam_size * per_node_beam_size)
reshaped_summed = summed_top_logprobs.reshape(
batch_size, self.beam_size * self.per_node_beam_size
)
# shape: (batch_size, beam_size * per_node_beam_size)
reshaped_predicted_classes = predicted_classes.reshape(
batch_size, self.beam_size * self.per_node_beam_size
)
# Append the predictions to the current beam.
reshaped_beam = (
predictions.view(batch_size * self.beam_size, 1, -1)
.repeat(1, self.per_node_beam_size, 1)
.reshape(batch_size, self.beam_size * self.per_node_beam_size, -1)
)
# shape: (batch_size, beam_size * per_node_beam_size, num_tokens)
reshaped_beam = torch.cat([reshaped_beam, reshaped_predicted_classes.unsqueeze(-1)], dim=-1)
# Keep only the top `beam_size` beam indices.
# shape: (batch_size, beam_size), (batch_size, beam_size)
restricted_beam_logprobs, restricted_beam_indices = reshaped_summed.topk(
self.beam_size
)
predictions = reshaped_beam.gather(
1, restricted_beam_indices.unsqueeze(-1).repeat(1,1,reshaped_beam.shape[-1])
)
# shape: (batch_size, beam_size)
last_logprobs = restricted_beam_logprobs
if not torch.isfinite(last_logprobs).all():
warnings.warn(
"Infinite log probs encountered. Some final descriptions may not "
"make sense. This can happen when the beam size is larger than"
" the number of valid (non-zero probability) transitions that "
"the step function produces.",
RuntimeWarning,
)
# Optionally select best beam and its logprobs.
if only_return_best:
# shape: (batch_size, sequence_length)
predictions = predictions[:, 0, :]
last_logprobs = last_logprobs[:, 0]
num_valid = (predictions != self._eos_index).sum(dim=-1)
num_valid += (predictions == self._eos_index).sum(dim=-1) > 0
num_valid = num_valid - begin_tokens.shape[1]
num_valid = num_valid.clip(min=1)
last_logprobs = last_logprobs / num_valid
return predictions, last_logprobs
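# Illustrative usage sketch (not part of the original file): the contract expected by
# AutoRegressiveBeamSearch.search. `step` maps partial token sequences of shape
# (batch_size * beam_size, length) to next-token logits; the vocabulary size, start
# token, and end token below are arbitrary example values.
def _demo_beam_search(vocab_size=10, begin_id=1, eos_id=2):
    def step(partial_tokens):
        # A stand-in for the real decoding step; returns random next-token logits.
        return torch.randn(partial_tokens.shape[0], vocab_size)

    searcher = AutoRegressiveBeamSearch(end_token_id=eos_id, max_steps=5, beam_size=3, objectdet=False)
    begin_tokens = torch.full((2, 1), begin_id, dtype=torch.long)
    predictions, logprobs = searcher.search(begin_tokens, step)
    return predictions, logprobs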
class GRiTTextDecoder(nn.Module):
def __init__(
self,
transformer,
begin_token_id=101,
beamsearch_decode=None,
loss_type=None,
tokenizer=None,
):
super().__init__()
self.textual = transformer
self.padding_idx = self.textual.padding_idx
self.begin_token_id = begin_token_id
self.beamsearch_decode = beamsearch_decode
self.tokenizer = tokenizer
if loss_type is None:
self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx)
elif loss_type == 'smooth':
self.loss = SmoothLabelCrossEntropyLoss(ignore_index=self.padding_idx)
else:
raise NotImplementedError(loss_type)
def forward(self, batch):
object_features = batch['object_features']
if self.training:
caption_token_input = batch["text_tokens"]
output_logits = self.textual(
object_features,
caption_token_input,
)
if 'need_predict' in batch:
# An in-place modification would also work, but we clone for safety since the
# tokens may be reused in prediction results in the future.
target = batch["text_tokens"].clone()
target[batch['need_predict'] == 0] = self.padding_idx
else:
target = batch["text_tokens"]
feat = output_logits[:, :-1].contiguous()
target = target[:, 1:].contiguous()
feat = feat.view(-1, self.textual.vocab_size)
target = target.view(-1)
valid_mask = target != self.padding_idx
target = target[valid_mask]
feat = feat[valid_mask]
loss = self.loss(feat, target)
return loss
else:
output_dict = self.infer(object_features)
return output_dict
def infer(self, object_features):
batch_size = object_features.size(0)
begin_tokens = object_features.new_full(
(batch_size, 1), self.begin_token_id
).long()
decoding_step = functools.partial(
self.decoding_step, object_features
)
object_description_tokens, logprobs = self.beamsearch_decode.search(
begin_tokens, decoding_step
)
output_dict = {
'predictions': object_description_tokens,
'logprobs': logprobs,
}
return output_dict
def decoding_step(self, object_features, partial_text):
batch_size = object_features.shape[0]
beam_size = int(partial_text.size(0) / batch_size)
if beam_size > 1:
batch_size, num_token, channels = object_features.size()
object_features = object_features.unsqueeze(1).repeat(1, beam_size, 1, 1)
object_features = object_features.view(
batch_size * beam_size, num_token, channels
)
text_lengths = torch.ones_like(partial_text)
if len(text_lengths.size()) != 2:
partial_text = partial_text.unsqueeze(1)
# shape: (batch_size * beam_size, partial_caption_length, vocab_size)
logits = self.textual(
object_features,
partial_text,
)
return logits[:, -1, :].float()
class SmoothLabelCrossEntropyLoss(nn.Module):
def __init__(self, eps=0.1, log_prefix='', ignore_index=None):
super().__init__()
self.eps = eps
self.log_soft = nn.LogSoftmax(dim=1)
self.kl = nn.KLDivLoss(reduction='none')
self.iter = 0
self.max_loss = 0
self.min_loss = 0
self.log_prefix = log_prefix
self.ignore_index = ignore_index
def forward(self, feature, target):
feature = feature.float()
if self.ignore_index is not None:
valid_mask = target != self.ignore_index
target = target[valid_mask]
feature = feature[valid_mask]
assert target.numel() > 0
self.iter += 1
eps = self.eps
n_class = feature.size(1)
one_hot = torch.zeros_like(feature).scatter(1, target.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = self.log_soft(feature)
loss = self.kl(log_prb, one_hot)
return loss.sum(dim=1).mean()
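# Illustrative usage sketch (not part of the original file): applying the label-smoothing
# loss to random logits; the class count, batch size, and padding index below are
# arbitrary example values.
def _demo_smooth_label_loss(num_classes=5, padding_idx=0):
    criterion = SmoothLabelCrossEntropyLoss(eps=0.1, ignore_index=padding_idx)
    logits = torch.randn(8, num_classes)
    targets = torch.randint(1, num_classes, (8,))   # avoid the padding index for this demo
    return criterion(logits, targets)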
import torch
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
class Visualizer_GRiT(Visualizer):
def __init__(self, image, instance_mode=None):
super().__init__(image, instance_mode=instance_mode)
def draw_instance_predictions(self, predictions):
boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
scores = predictions.scores if predictions.has("scores") else None
classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
object_description = predictions.pred_object_descriptions.data
if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
colors = [
self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
]
alpha = 0.8
else:
colors = None
alpha = 0.5
if self._instance_mode == ColorMode.IMAGE_BW:
self.output.reset_image(
self._create_grayscale_image(
(predictions.pred_masks.any(dim=0) > 0).numpy()
if predictions.has("pred_masks")
else None
)
)
alpha = 0.3
self.overlay_instances(
masks=None,
boxes=boxes,
labels=object_description,
keypoints=None,
assigned_colors=colors,
alpha=alpha,
)
return self.output
class VisualizationDemo(object):
def __init__(self, cfg, instance_mode=ColorMode.IMAGE):
self.cpu_device = torch.device("cpu")
self.instance_mode = instance_mode
self.predictor = DefaultPredictor(cfg)
def run_on_image(self, image):
predictions = self.predictor(image)
# Convert image from OpenCV BGR format to Matplotlib RGB format.
image = image[:, :, ::-1]
visualizer = Visualizer_GRiT(image, instance_mode=self.instance_mode)
instances = predictions["instances"].to(self.cpu_device)
vis_output = visualizer.draw_instance_predictions(predictions=instances)
return predictions, vis_output
class BatchPredictor(DefaultPredictor):
def __init__(self, cfg):
super().__init__(cfg)
def __call__(self, original_images):
input_list = []
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
# Apply pre-processing to image.
for original_image in original_images:
if self.input_format == "RGB":
# whether the model expects BGR inputs or RGB
original_image = original_image[:, :, ::-1]
height, width = original_image.shape[:2]
image = self.aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
input = {"image": image, "height": height, "width": width}
input_list.append(input)
predictions = self.model(input_list)
return predictions
class BatchVisualizationDemo(object):
def __init__(self, cfg):
self.cpu_device = torch.device("cpu")
self.predictor = BatchPredictor(cfg)
def run_on_images(self, images):
predictions = self.predictor(images)
return predictions
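# Illustrative usage sketch (not part of the original file): the expected call pattern
# for VisualizationDemo. `cfg` is assumed to be a detectron2 config prepared elsewhere
# and `image` a BGR numpy array (e.g. read with cv2.imread).
def _demo_run_on_image(cfg, image):
    demo = VisualizationDemo(cfg)
    predictions, vis_output = demo.run_on_image(image)
    return predictions["instances"], vis_output.get_image()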
version: 2.1
# -------------------------------------------------------------------------------------
# Environments to run the jobs in
# -------------------------------------------------------------------------------------
cpu: &cpu
machine:
image: ubuntu-2004:202107-02
resource_class: medium
gpu: &gpu
machine:
# NOTE: use a CUDA version that's supported by all our pytorch versions
image: ubuntu-1604-cuda-11.1:202012-01
resource_class: gpu.nvidia.small
windows-cpu: &windows_cpu
machine:
resource_class: windows.medium
image: windows-server-2019-vs2019:stable
shell: powershell.exe
# windows-gpu: &windows_gpu
# machine:
# resource_class: windows.gpu.nvidia.medium
# image: windows-server-2019-nvidia:stable
version_parameters: &version_parameters
parameters:
pytorch_version:
type: string
torchvision_version:
type: string
pytorch_index:
type: string
# use test wheels index to have access to RC wheels
# https://download.pytorch.org/whl/test/torch_test.html
default: "https://download.pytorch.org/whl/torch_stable.html"
python_version: # NOTE: only affect linux
type: string
default: '3.6.8'
environment:
PYTORCH_VERSION: << parameters.pytorch_version >>
TORCHVISION_VERSION: << parameters.torchvision_version >>
PYTORCH_INDEX: << parameters.pytorch_index >>
PYTHON_VERSION: << parameters.python_version>>
# point datasets to ~/.torch so it's cached in CI
DETECTRON2_DATASETS: ~/.torch/datasets
# -------------------------------------------------------------------------------------
# Re-usable commands
# -------------------------------------------------------------------------------------
# install_nvidia_driver: &install_nvidia_driver
# - run:
# name: Install nvidia driver
# working_directory: ~/
# command: |
# wget -q 'https://s3.amazonaws.com/ossci-linux/nvidia_driver/NVIDIA-Linux-x86_64-430.40.run'
# sudo /bin/bash ./NVIDIA-Linux-x86_64-430.40.run -s --no-drm
# nvidia-smi
add_ssh_keys: &add_ssh_keys
# https://circleci.com/docs/2.0/add-ssh-key/
- add_ssh_keys:
fingerprints:
- "e4:13:f2:22:d4:49:e8:e4:57:5a:ac:20:2f:3f:1f:ca"
install_python: &install_python
- run:
name: Install Python
working_directory: ~/
command: |
# upgrade pyenv
cd /opt/circleci/.pyenv/plugins/python-build/../.. && git pull && cd -
pyenv install -s $PYTHON_VERSION
pyenv global $PYTHON_VERSION
python --version
which python
pip install --upgrade pip
setup_venv: &setup_venv
- run:
name: Setup Virtual Env
working_directory: ~/
command: |
python -m venv ~/venv
echo ". ~/venv/bin/activate" >> $BASH_ENV
. ~/venv/bin/activate
python --version
which python
which pip
pip install --upgrade pip
setup_venv_win: &setup_venv_win
- run:
name: Setup Virtual Env for Windows
command: |
pip install virtualenv
python -m virtualenv env
.\env\Scripts\activate
python --version
which python
which pip
install_linux_dep: &install_linux_dep
- run:
name: Install Dependencies
command: |
# disable crash coredump, so unittests fail fast
sudo systemctl stop apport.service
# install from github to get latest; install iopath first since fvcore depends on it
pip install --progress-bar off -U 'git+https://github.com/facebookresearch/iopath'
pip install --progress-bar off -U 'git+https://github.com/facebookresearch/fvcore'
# Don't use pytest-xdist: cuda tests are unstable under multi-process workers.
pip install --progress-bar off ninja opencv-python-headless pytest tensorboard pycocotools
pip install --progress-bar off torch==$PYTORCH_VERSION -f $PYTORCH_INDEX
if [[ "$TORCHVISION_VERSION" == "master" ]]; then
pip install git+https://github.com/pytorch/vision.git
else
pip install --progress-bar off torchvision==$TORCHVISION_VERSION -f $PYTORCH_INDEX
fi
python -c 'import torch; print("CUDA:", torch.cuda.is_available())'
gcc --version
install_detectron2: &install_detectron2
- run:
name: Install Detectron2
command: |
# Remove first, in case it's in the CI cache
pip uninstall -y detectron2
pip install --progress-bar off -e .[all]
python -m detectron2.utils.collect_env
./datasets/prepare_for_tests.sh
run_unittests: &run_unittests
- run:
name: Run Unit Tests
command: |
pytest -v --durations=15 tests # parallel causes some random failures
# -------------------------------------------------------------------------------------
# Jobs to run
# -------------------------------------------------------------------------------------
jobs:
linux_cpu_tests:
<<: *cpu
<<: *version_parameters
working_directory: ~/detectron2
steps:
- checkout
# Cache the venv directory that contains python, dependencies, and checkpoints
# Refresh the key when dependencies should be updated (e.g. when pytorch releases)
- restore_cache:
keys:
- cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210827
- <<: *install_python
- <<: *install_linux_dep
- <<: *install_detectron2
- <<: *run_unittests
- save_cache:
paths:
- /opt/circleci/.pyenv
- ~/.torch
key: cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210827
linux_gpu_tests:
<<: *gpu
<<: *version_parameters
working_directory: ~/detectron2
steps:
- checkout
- restore_cache:
keys:
- cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210827
- <<: *install_python
- <<: *install_linux_dep
- <<: *install_detectron2
- <<: *run_unittests
- save_cache:
paths:
- /opt/circleci/.pyenv
- ~/.torch
key: cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210827
windows_cpu_build:
<<: *windows_cpu
<<: *version_parameters
steps:
- <<: *add_ssh_keys
- checkout
- <<: *setup_venv_win
# Cache the env directory that contains dependencies
- restore_cache:
keys:
- cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210404
- run:
name: Install Dependencies
command: |
pip install certifi --ignore-installed # required on windows to workaround some cert issue
pip install numpy cython # required on windows before pycocotools
pip install opencv-python-headless pytest-xdist pycocotools tensorboard
pip install -U git+https://github.com/facebookresearch/iopath
pip install -U git+https://github.com/facebookresearch/fvcore
pip install torch==$env:PYTORCH_VERSION torchvision==$env:TORCHVISION_VERSION -f $env:PYTORCH_INDEX
- save_cache:
paths:
- env
key: cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210404
- <<: *install_detectron2
# TODO: unittest fails for now
workflows:
version: 2
regular_test:
jobs:
- linux_cpu_tests:
name: linux_cpu_tests_pytorch1.10
pytorch_version: '1.10.0+cpu'
torchvision_version: '0.11.1+cpu'
- linux_gpu_tests:
name: linux_gpu_tests_pytorch1.8
pytorch_version: '1.8.1+cu111'
torchvision_version: '0.9.1+cu111'
- linux_gpu_tests:
name: linux_gpu_tests_pytorch1.9
pytorch_version: '1.9+cu111'
torchvision_version: '0.10+cu111'
- linux_gpu_tests:
name: linux_gpu_tests_pytorch1.10
pytorch_version: '1.10+cu111'
torchvision_version: '0.11.1+cu111'
- linux_gpu_tests:
name: linux_gpu_tests_pytorch1.10_python39
pytorch_version: '1.10+cu111'
torchvision_version: '0.11.1+cu111'
python_version: '3.9.6'
- windows_cpu_build:
pytorch_version: '1.10+cpu'
torchvision_version: '0.11.1+cpu'
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
IncludeCategories:
- Regex: '^<.*\.h(pp)?>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
# This is an example .flake8 config, used when developing *Black* itself.
# Keep in sync with setup.cfg which is used for source packages.
[flake8]
ignore = W503, E203, E221, C901, C408, E741, C407, B017
max-line-length = 100
max-complexity = 18
select = B,C,E,F,W,T4,B9
exclude = build
per-file-ignores =
**/__init__.py:F401,F403,E402
**/configs/**.py:F401,E402
configs/**.py:F401,E402
**/tests/config/**.py:F401,E402
tests/config/**.py:F401,E402