import torch
import torch.nn as nn
import torch.cuda.amp as amp
from src.core import register
import src.misc.dist as dist
__all__ = ['GradScaler']
GradScaler = register(amp.grad_scaler.GradScaler)
"""
reference:
https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py#L404
by lyuwenyu
"""
import torch
import torch.nn as nn
import math
from copy import deepcopy
from src.core import register
import src.misc.dist as dist
__all__ = ['ModelEMA']
@register
class ModelEMA(object):
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
This class is sensitive to where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
"""
def __init__(self, model: nn.Module, decay: float=0.9999, warmups: int=2000):
super().__init__()
# Create EMA
self.module = deepcopy(dist.de_parallel(model)).eval() # FP32 EMA
# if next(model.parameters()).device.type != 'cpu':
# self.module.half() # FP16 EMA
self.decay = decay
self.warmups = warmups
self.updates = 0 # number of EMA updates
# self.filter_no_grad = filter_no_grad
self.decay_fn = lambda x: decay * (1 - math.exp(-x / warmups)) # decay exponential ramp (to help early epochs)
for p in self.module.parameters():
p.requires_grad_(False)
def update(self, model: nn.Module):
# Update EMA parameters
with torch.no_grad():
self.updates += 1
d = self.decay_fn(self.updates)
msd = dist.de_parallel(model).state_dict()
for k, v in self.module.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1 - d) * msd[k].detach()
def to(self, *args, **kwargs):
self.module = self.module.to(*args, **kwargs)
return self
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes
self.copy_attr(self.module, model, include, exclude)
@staticmethod
def copy_attr(a, b, include=(), exclude=()):
# Copy attributes from b to a, options to only include [...] and to exclude [...]
for k, v in b.__dict__.items():
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
continue
else:
setattr(a, k, v)
def state_dict(self, ):
return dict(module=self.module.state_dict(), updates=self.updates, warmups=self.warmups)
def load_state_dict(self, state):
self.module.load_state_dict(state['module'])
if 'updates' in state:
self.updates = state['updates']
def forward(self, ):
raise RuntimeError('ModelEMA is not callable; use ema.module for inference.')
def extra_repr(self) -> str:
return f'decay={self.decay}, warmups={self.warmups}'
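# Illustrative usage sketch (added; names are placeholders, not part of the original file):
# the EMA copy is updated once per optimizer step and ema.module is the smoothed model used
# for evaluation. With decay=0.9999 and warmups=2000, the effective decay
# decay * (1 - exp(-updates / warmups)) ramps from ~0.0005 at the first update to ~0.63
# after 2000 updates, so early epochs track the raw weights closely.
#
#   ema = ModelEMA(model, decay=0.9999, warmups=2000)
#   for samples, targets in dataloader:
#       loss = sum(criterion(model(samples, targets), targets).values())
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       ema.update(model)
#   eval_model = ema.module  # FP32 EMA copy with requires_grad disabled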
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
"""Maintains moving averages of model parameters using an exponential decay.
``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
`torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
is used to compute the EMA.
"""
def __init__(self, model, decay, device="cpu", use_buffers=True):
self.decay_fn = lambda x: decay * (1 - math.exp(-x / 2000))
def ema_avg(avg_model_param, model_param, num_averaged):
decay = self.decay_fn(num_averaged)
return decay * avg_model_param + (1 - decay) * model_param
super().__init__(model, device, ema_avg, use_buffers=use_buffers)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from src.core import register
__all__ = ['AdamW', 'SGD', 'Adam', 'MultiStepLR', 'CosineAnnealingLR', 'OneCycleLR', 'LambdaLR']
SGD = register(optim.SGD)
Adam = register(optim.Adam)
AdamW = register(optim.AdamW)
MultiStepLR = register(lr_scheduler.MultiStepLR)
CosineAnnealingLR = register(lr_scheduler.CosineAnnealingLR)
OneCycleLR = register(lr_scheduler.OneCycleLR)
LambdaLR = register(lr_scheduler.LambdaLR)
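# Note (added for clarity): these are the stock torch.optim / lr_scheduler classes, wrapped
# with src.core.register so they can presumably be resolved by name from the config system.
# They remain ordinary PyTorch objects and can still be constructed directly, e.g.
#   optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
#   scheduler = MultiStepLR(optimizer, milestones=[1000], gamma=0.1)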
"""by lyuwenyu
"""
from .solver import BaseSolver
from .det_solver import DetSolver
from typing import Dict
TASKS :Dict[str, BaseSolver] = {
'detection': DetSolver,
}
"""
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
https://github.com/facebookresearch/detr/blob/main/engine.py
by lyuwenyu
"""
import math
import os
import sys
import pathlib
from typing import Iterable
import torch
import torch.amp
from src.data import CocoEvaluator
from src.misc import (MetricLogger, SmoothedValue, reduce_dict)
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, max_norm: float = 0, **kwargs):
model.train()
criterion.train()
metric_logger = MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
# metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = kwargs.get('print_freq', 10)
ema = kwargs.get('ema', None)
scaler = kwargs.get('scaler', None)
for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
samples = samples.to(device)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
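# AMP path below (descriptive note): the forward pass runs under autocast, the criterion is
# evaluated with autocast disabled (FP32), and when clipping is enabled the scaler unscales
# the gradients before clip_grad_norm_ so that max_norm applies to true gradient magnitudes.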
if scaler is not None:
with torch.autocast(device_type=str(device), cache_enabled=True):
outputs = model(samples, targets)
with torch.autocast(device_type=str(device), enabled=False):
loss_dict = criterion(outputs, targets)
loss = sum(loss_dict.values())
scaler.scale(loss).backward()
if max_norm > 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
else:
outputs = model(samples, targets)
loss_dict = criterion(outputs, targets)
loss = sum(loss_dict.values())
optimizer.zero_grad()
loss.backward()
if max_norm > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()
# ema
if ema is not None:
ema.update(model)
loss_dict_reduced = reduce_dict(loss_dict)
loss_value = sum(loss_dict_reduced.values())
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
print(loss_dict_reduced)
sys.exit(1)
metric_logger.update(loss=loss_value, **loss_dict_reduced)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(model: torch.nn.Module, criterion: torch.nn.Module, postprocessors, data_loader, base_ds, device, output_dir):
model.eval()
criterion.eval()
metric_logger = MetricLogger(delimiter=" ")
# metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}'))
header = 'Test:'
# iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
iou_types = postprocessors.iou_types
coco_evaluator = CocoEvaluator(base_ds, iou_types)
# coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
panoptic_evaluator = None
# if 'panoptic' in postprocessors.keys():
# panoptic_evaluator = PanopticEvaluator(
# data_loader.dataset.ann_file,
# data_loader.dataset.ann_folder,
# output_dir=os.path.join(output_dir, "panoptic_eval"),
# )
for samples, targets in metric_logger.log_every(data_loader, 10, header):
samples = samples.to(device)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
# with torch.autocast(device_type=str(device)):
# outputs = model(samples)
outputs = model(samples)
# loss_dict = criterion(outputs, targets)
# weight_dict = criterion.weight_dict
# # reduce losses over all GPUs for logging purposes
# loss_dict_reduced = reduce_dict(loss_dict)
# loss_dict_reduced_scaled = {k: v * weight_dict[k]
# for k, v in loss_dict_reduced.items() if k in weight_dict}
# loss_dict_reduced_unscaled = {f'{k}_unscaled': v
# for k, v in loss_dict_reduced.items()}
# metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
# **loss_dict_reduced_scaled,
# **loss_dict_reduced_unscaled)
# metric_logger.update(class_error=loss_dict_reduced['class_error'])
orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
results = postprocessors(outputs, orig_target_sizes)
# results = postprocessors(outputs, targets)
# if 'segm' in postprocessors.keys():
# target_sizes = torch.stack([t["size"] for t in targets], dim=0)
# results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
res = {target['image_id'].item(): output for target, output in zip(targets, results)}
if coco_evaluator is not None:
coco_evaluator.update(res)
# if panoptic_evaluator is not None:
# res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
# for i, target in enumerate(targets):
# image_id = target["image_id"].item()
# file_name = f"{image_id:012d}.png"
# res_pano[i]["image_id"] = image_id
# res_pano[i]["file_name"] = file_name
# panoptic_evaluator.update(res_pano)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
if coco_evaluator is not None:
coco_evaluator.synchronize_between_processes()
if panoptic_evaluator is not None:
panoptic_evaluator.synchronize_between_processes()
# accumulate predictions from all images
if coco_evaluator is not None:
coco_evaluator.accumulate()
coco_evaluator.summarize()
# panoptic_res = None
# if panoptic_evaluator is not None:
# panoptic_res = panoptic_evaluator.summarize()
stats = {}
# stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
if coco_evaluator is not None:
if 'bbox' in iou_types:
stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
if 'segm' in iou_types:
stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
# if panoptic_res is not None:
# stats['PQ_all'] = panoptic_res["All"]
# stats['PQ_th'] = panoptic_res["Things"]
# stats['PQ_st'] = panoptic_res["Stuff"]
return stats, coco_evaluator
'''
by lyuwenyu
'''
import time
import json
import datetime
import torch
from src.misc import dist
from src.data import get_coco_api_from_dataset
from .solver import BaseSolver
from .det_engine import train_one_epoch, evaluate
class DetSolver(BaseSolver):
def fit(self, ):
print("Start training")
self.train()
args = self.cfg
n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
# best_stat = {'coco_eval_bbox': 0, 'coco_eval_masks': 0, 'epoch': -1, }
best_stat = {'epoch': -1, }
start_time = time.time()
for epoch in range(self.last_epoch + 1, args.epoches):
if dist.is_dist_available_and_initialized():
self.train_dataloader.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
self.model, self.criterion, self.train_dataloader, self.optimizer, self.device, epoch,
args.clip_max_norm, print_freq=args.log_step, ema=self.ema, scaler=self.scaler)
self.lr_scheduler.step()
if self.output_dir:
checkpoint_paths = [self.output_dir / 'checkpoint.pth']
# extra checkpoint before LR drop and every 100 epochs
if (epoch + 1) % args.checkpoint_step == 0:
checkpoint_paths.append(self.output_dir / f'checkpoint{epoch:04}.pth')
for checkpoint_path in checkpoint_paths:
dist.save_on_master(self.state_dict(epoch), checkpoint_path)
module = self.ema.module if self.ema else self.model
test_stats, coco_evaluator = evaluate(
module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir
)
# TODO
for k in test_stats.keys():
if k in best_stat:
best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']
best_stat[k] = max(best_stat[k], test_stats[k][0])
else:
best_stat['epoch'] = epoch
best_stat[k] = test_stats[k][0]
print('best_stat: ', best_stat)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if self.output_dir and dist.is_main_process():
with (self.output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
# for evaluation logs
if coco_evaluator is not None:
(self.output_dir / 'eval').mkdir(exist_ok=True)
if "bbox" in coco_evaluator.coco_eval:
filenames = ['latest.pth']
if epoch % 50 == 0:
filenames.append(f'{epoch:03}.pth')
for name in filenames:
torch.save(coco_evaluator.coco_eval["bbox"].eval,
self.output_dir / "eval" / name)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
def val(self, ):
self.eval()
base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
module = self.ema.module if self.ema else self.model
test_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor,
self.val_dataloader, base_ds, self.device, self.output_dir)
if self.output_dir:
dist.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")
return
"""by lyuwenyu
"""
import torch
import torch.nn as nn
from datetime import datetime
from pathlib import Path
from typing import Dict
from src.misc import dist
from src.core import BaseConfig
class BaseSolver(object):
def __init__(self, cfg: BaseConfig) -> None:
self.cfg = cfg
def setup(self, ):
'''Avoid instantiating unnecessary classes
'''
cfg = self.cfg
device = cfg.device
self.device = device
self.last_epoch = cfg.last_epoch
self.model = dist.warp_model(cfg.model.to(device), cfg.find_unused_parameters, cfg.sync_bn)
self.criterion = cfg.criterion.to(device)
self.postprocessor = cfg.postprocessor
# NOTE (lvwenyu): load_tuning_state should be called before the EMA instance is built
if self.cfg.tuning:
print(f'Tuning checkpoint from {self.cfg.tuning}')
self.load_tuning_state(self.cfg.tuning)
self.scaler = cfg.scaler
self.ema = cfg.ema.to(device) if cfg.ema is not None else None
self.output_dir = Path(cfg.output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
def train(self, ):
self.setup()
self.optimizer = self.cfg.optimizer
self.lr_scheduler = self.cfg.lr_scheduler
# NOTE instantiating order
if self.cfg.resume:
print(f'Resume checkpoint from {self.cfg.resume}')
self.resume(self.cfg.resume)
self.train_dataloader = dist.warp_loader(self.cfg.train_dataloader, \
shuffle=self.cfg.train_dataloader.shuffle)
self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, \
shuffle=self.cfg.val_dataloader.shuffle)
def eval(self, ):
self.setup()
self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, \
shuffle=self.cfg.val_dataloader.shuffle)
if self.cfg.resume:
print(f'resume from {self.cfg.resume}')
self.resume(self.cfg.resume)
def state_dict(self, last_epoch):
'''state dict
'''
state = {}
state['model'] = dist.de_parallel(self.model).state_dict()
state['date'] = datetime.now().isoformat()
# TODO
state['last_epoch'] = last_epoch
if self.optimizer is not None:
state['optimizer'] = self.optimizer.state_dict()
if self.lr_scheduler is not None:
state['lr_scheduler'] = self.lr_scheduler.state_dict()
# state['last_epoch'] = self.lr_scheduler.last_epoch
if self.ema is not None:
state['ema'] = self.ema.state_dict()
if self.scaler is not None:
state['scaler'] = self.scaler.state_dict()
return state
def load_state_dict(self, state):
'''load state dict
'''
# TODO
if getattr(self, 'last_epoch', None) and 'last_epoch' in state:
self.last_epoch = state['last_epoch']
print('Loading last_epoch')
if getattr(self, 'model', None) and 'model' in state:
if dist.is_parallel(self.model):
self.model.module.load_state_dict(state['model'])
else:
self.model.load_state_dict(state['model'])
print('Loading model.state_dict')
if getattr(self, 'ema', None) and 'ema' in state:
self.ema.load_state_dict(state['ema'])
print('Loading ema.state_dict')
if getattr(self, 'optimizer', None) and 'optimizer' in state:
self.optimizer.load_state_dict(state['optimizer'])
print('Loading optimizer.state_dict')
if getattr(self, 'lr_scheduler', None) and 'lr_scheduler' in state:
self.lr_scheduler.load_state_dict(state['lr_scheduler'])
print('Loading lr_scheduler.state_dict')
if getattr(self, 'scaler', None) and 'scaler' in state:
self.scaler.load_state_dict(state['scaler'])
print('Loading scaler.state_dict')
def save(self, path):
'''save state
'''
state = self.state_dict()
dist.save_on_master(state, path)
def resume(self, path):
'''load resume
'''
# map to cpu so the checkpoint does not occupy cuda:0 memory while loading
state = torch.load(path, map_location='cpu')
self.load_state_dict(state)
def load_tuning_state(self, path,):
"""only load model for tuning and skip missed/dismatched keys
"""
if 'http' in path:
state = torch.hub.load_state_dict_from_url(path, map_location='cpu')
else:
state = torch.load(path, map_location='cpu')
module = dist.de_parallel(self.model)
# TODO hard code
if 'ema' in state:
stat, infos = self._matched_state(module.state_dict(), state['ema']['module'])
else:
stat, infos = self._matched_state(module.state_dict(), state['model'])
module.load_state_dict(stat, strict=False)
print(f'Load model.state_dict, {infos}')
@staticmethod
def _matched_state(state: Dict[str, torch.Tensor], params: Dict[str, torch.Tensor]):
missed_list = []
unmatched_list = []
matched_state = {}
for k, v in state.items():
if k in params:
if v.shape == params[k].shape:
matched_state[k] = params[k]
else:
unmatched_list.append(k)
else:
missed_list.append(k)
return matched_state, {'missed': missed_list, 'unmatched': unmatched_list}
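# Example of the matching behaviour (illustrative): when fine-tuning a checkpoint trained on
# a different number of classes, the classification-head tensors end up in 'unmatched'
# (shape mismatch) and keep their fresh initialization, while every other weight with the same
# name and shape is loaded; keys present in the model but absent from the checkpoint are
# reported in 'missed'.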
def fit(self, ):
raise NotImplementedError('')
def val(self, ):
raise NotImplementedError('')
"""by lyuwenyu
"""
from .rtdetr import *
from .hybrid_encoder import *
from .rtdetr_decoder import *
from .rtdetr_postprocessor import *
from .rtdetr_criterion import *
from .matcher import *
'''
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
https://github.com/facebookresearch/detr/blob/main/util/box_ops.py
'''
import torch
from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2,
(x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / area
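# Quick sanity check (illustrative, not part of the original file): identical boxes give
# GIoU = IoU = 1, while disjoint boxes are pushed negative by the enclosing-area penalty.
#   b1 = torch.tensor([[0., 0., 1., 1.]])
#   b2 = torch.tensor([[0., 0., 1., 1.], [2., 2., 3., 3.]])
#   generalized_box_iou(b1, b2)  # approx [[1.0000, -0.7778]]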
def masks_to_boxes(masks):
"""Compute the bounding boxes around the provided masks
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
Returns a [N, 4] tensors, with the boxes in xyxy format
"""
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device)
h, w = masks.shape[-2:]
y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x)
x_mask = (masks * x.unsqueeze(0))
x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
y_mask = (masks * y.unsqueeze(0))
y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
return torch.stack([x_min, y_min, x_max, y_max], 1)
"""by lyuwenyu
"""
import torch
from .utils import inverse_sigmoid
from .box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
def get_contrastive_denoising_training_group(targets,
num_classes,
num_queries,
class_embed,
num_denoising=100,
label_noise_ratio=0.5,
box_noise_scale=1.0,):
"""cnd"""
if num_denoising <= 0:
return None, None, None, None
num_gts = [len(t['labels']) for t in targets]
device = targets[0]['labels'].device
max_gt_num = max(num_gts)
if max_gt_num == 0:
return None, None, None, None
num_group = num_denoising // max_gt_num
num_group = 1 if num_group == 0 else num_group
# pad gt to max_num of a batch
bs = len(num_gts)
input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device)
input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device)
pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device)
for i in range(bs):
num_gt = num_gts[i]
if num_gt > 0:
input_query_class[i, :num_gt] = targets[i]['labels']
input_query_bbox[i, :num_gt] = targets[i]['boxes']
pad_gt_mask[i, :num_gt] = 1
# each group has positive and negative queries.
input_query_class = input_query_class.tile([1, 2 * num_group])
input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])
# positive and negative mask
negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device)
negative_gt_mask[:, max_gt_num:] = 1
negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])
positive_gt_mask = 1 - negative_gt_mask
# contrastive denoising training positive index
positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts])
# total denoising queries
num_denoising = int(max_gt_num * 2 * num_group)
if label_noise_ratio > 0:
mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
# randomly put a new one here
new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
# if label_noise_ratio > 0:
# input_query_class = input_query_class.flatten()
# pad_gt_mask = pad_gt_mask.flatten()
# # half of bbox prob
# # mask = torch.rand(input_query_class.shape, device=device) < (label_noise_ratio * 0.5)
# mask = torch.rand_like(input_query_class) < (label_noise_ratio * 0.5)
# chosen_idx = torch.nonzero(mask * pad_gt_mask).squeeze(-1)
# # randomly put a new one here
# new_label = torch.randint_like(chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
# # input_query_class.scatter_(dim=0, index=chosen_idx, value=new_label)
# input_query_class[chosen_idx] = new_label
# input_query_class = input_query_class.reshape(bs, num_denoising)
# pad_gt_mask = pad_gt_mask.reshape(bs, num_denoising)
if box_noise_scale > 0:
known_bbox = box_cxcywh_to_xyxy(input_query_bbox)
diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
rand_part = torch.rand_like(input_query_bbox)
rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
rand_part *= rand_sign
known_bbox += rand_part * diff
known_bbox.clip_(min=0.0, max=1.0)
input_query_bbox = box_xyxy_to_cxcywh(known_bbox)
input_query_bbox = inverse_sigmoid(input_query_bbox)
# class_embed = torch.concat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=device)])
# input_query_class = torch.gather(
# class_embed, input_query_class.flatten(),
# axis=0).reshape(bs, num_denoising, -1)
# input_query_class = class_embed(input_query_class.flatten()).reshape(bs, num_denoising, -1)
input_query_class = class_embed(input_query_class)
tgt_size = num_denoising + num_queries
# attn_mask = torch.ones([tgt_size, tgt_size], device=device) < 0
attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device)
# match query cannot see the reconstruction
attn_mask[num_denoising:, :num_denoising] = True
# reconstruct cannot see each other
for i in range(num_group):
if i == 0:
attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
if i == num_group - 1:
attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * i * 2] = True
else:
attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * 2 * i] = True
dn_meta = {
"dn_positive_idx": dn_positive_idx,
"dn_num_group": num_group,
"dn_num_split": [num_denoising, num_queries]
}
# print(input_query_class.shape) # torch.Size([4, 196, 256])
# print(input_query_bbox.shape) # torch.Size([4, 196, 4])
# print(attn_mask.shape) # torch.Size([496, 496])
return input_query_class, input_query_bbox, attn_mask, dn_meta
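# Worked example of the grouping (added for clarity): with num_denoising=100 and a batch
# whose largest image has max_gt_num=20 boxes, num_group = 100 // 20 = 5 and each group holds
# 20 positive + 20 negative queries, so the returned num_denoising is 20 * 2 * 5 = 200 and
# attn_mask is [200 + num_queries, 200 + num_queries]. Matching queries cannot attend to the
# denoising queries, and denoising groups cannot attend to one another.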
'''by lyuwenyu
'''
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import get_activation
from src.core import register
__all__ = ['HybridEncoder']
class ConvNormLayer(nn.Module):
def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
super().__init__()
self.conv = nn.Conv2d(
ch_in,
ch_out,
kernel_size,
stride,
padding=(kernel_size-1)//2 if padding is None else padding,
bias=bias)
self.norm = nn.BatchNorm2d(ch_out)
self.act = nn.Identity() if act is None else get_activation(act)
def forward(self, x):
return self.act(self.norm(self.conv(x)))
class RepVggBlock(nn.Module):
def __init__(self, ch_in, ch_out, act='relu'):
super().__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
self.act = nn.Identity() if act is None else get_activation(act)
def forward(self, x):
if hasattr(self, 'conv'):
y = self.conv(x)
else:
y = self.conv1(x) + self.conv2(x)
return self.act(y)
def convert_to_deploy(self):
if not hasattr(self, 'conv'):
self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
kernel, bias = self.get_equivalent_kernel_bias()
self.conv.weight.data = kernel
self.conv.bias.data = bias
# self.__delattr__('conv1')
# self.__delattr__('conv2')
def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
if kernel1x1 is None:
return 0
else:
return F.pad(kernel1x1, [1, 1, 1, 1])
def _fuse_bn_tensor(self, branch: ConvNormLayer):
if branch is None:
return 0, 0
kernel = branch.conv.weight
running_mean = branch.norm.running_mean
running_var = branch.norm.running_var
gamma = branch.norm.weight
beta = branch.norm.bias
eps = branch.norm.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std
class CSPRepLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
num_blocks=3,
expansion=1.0,
bias=None,
act="silu"):
super(CSPRepLayer, self).__init__()
hidden_channels = int(out_channels * expansion)
self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
self.bottlenecks = nn.Sequential(*[
RepVggBlock(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)
])
if hidden_channels != out_channels:
self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
else:
self.conv3 = nn.Identity()
def forward(self, x):
x_1 = self.conv1(x)
x_1 = self.bottlenecks(x_1)
x_2 = self.conv2(x)
return self.conv3(x_1 + x_2)
# transformer
class TransformerEncoderLayer(nn.Module):
def __init__(self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False):
super().__init__()
self.normalize_before = normalize_before
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = get_activation(activation)
@staticmethod
def with_pos_embed(tensor, pos_embed):
return tensor if pos_embed is None else tensor + pos_embed
def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor:
residual = src
if self.normalize_before:
src = self.norm1(src)
q = k = self.with_pos_embed(src, pos_embed)
src, _ = self.self_attn(q, k, value=src, attn_mask=src_mask)
src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
return src
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoder, self).__init__()
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
self.num_layers = num_layers
self.norm = norm
def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor:
output = src
for layer in self.layers:
output = layer(output, src_mask=src_mask, pos_embed=pos_embed)
if self.norm is not None:
output = self.norm(output)
return output
@register
class HybridEncoder(nn.Module):
def __init__(self,
in_channels=[512, 1024, 2048],
feat_strides=[8, 16, 32],
hidden_dim=256,
nhead=8,
dim_feedforward = 1024,
dropout=0.0,
enc_act='gelu',
use_encoder_idx=[2],
num_encoder_layers=1,
pe_temperature=10000,
expansion=1.0,
depth_mult=1.0,
act='silu',
eval_spatial_size=None):
super().__init__()
self.in_channels = in_channels
self.feat_strides = feat_strides
self.hidden_dim = hidden_dim
self.use_encoder_idx = use_encoder_idx
self.num_encoder_layers = num_encoder_layers
self.pe_temperature = pe_temperature
self.eval_spatial_size = eval_spatial_size
self.out_channels = [hidden_dim for _ in range(len(in_channels))]
self.out_strides = feat_strides
# channel projection
self.input_proj = nn.ModuleList()
for in_channel in in_channels:
self.input_proj.append(
nn.Sequential(
nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False),
nn.BatchNorm2d(hidden_dim)
)
)
# encoder transformer
encoder_layer = TransformerEncoderLayer(
hidden_dim,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=enc_act)
self.encoder = nn.ModuleList([
TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx))
])
# top-down fpn
self.lateral_convs = nn.ModuleList()
self.fpn_blocks = nn.ModuleList()
for _ in range(len(in_channels) - 1, 0, -1):
self.lateral_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act))
self.fpn_blocks.append(
CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
)
# bottom-up pan
self.downsample_convs = nn.ModuleList()
self.pan_blocks = nn.ModuleList()
for _ in range(len(in_channels) - 1):
self.downsample_convs.append(
ConvNormLayer(hidden_dim, hidden_dim, 3, 2, act=act)
)
self.pan_blocks.append(
CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
)
self._reset_parameters()
def _reset_parameters(self):
if self.eval_spatial_size:
for idx in self.use_encoder_idx:
stride = self.feat_strides[idx]
pos_embed = self.build_2d_sincos_position_embedding(
self.eval_spatial_size[1] // stride, self.eval_spatial_size[0] // stride,
self.hidden_dim, self.pe_temperature)
setattr(self, f'pos_embed{idx}', pos_embed)
# self.register_buffer(f'pos_embed{idx}', pos_embed)
@staticmethod
def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.):
'''Build a [1, w*h, embed_dim] 2D sin-cos positional embedding.
'''
grid_w = torch.arange(int(w), dtype=torch.float32)
grid_h = torch.arange(int(h), dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
assert embed_dim % 4 == 0, \
'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1. / (temperature ** omega)
out_w = grid_w.flatten()[..., None] @ omega[None]
out_h = grid_h.flatten()[..., None] @ omega[None]
return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]
def forward(self, feats):
assert len(feats) == len(self.in_channels)
proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
# encoder
if self.num_encoder_layers > 0:
for i, enc_ind in enumerate(self.use_encoder_idx):
h, w = proj_feats[enc_ind].shape[2:]
# flatten [B, C, H, W] to [B, HxW, C]
src_flatten = proj_feats[enc_ind].flatten(2).permute(0, 2, 1)
if self.training or self.eval_spatial_size is None:
pos_embed = self.build_2d_sincos_position_embedding(
w, h, self.hidden_dim, self.pe_temperature).to(src_flatten.device)
else:
pos_embed = getattr(self, f'pos_embed{enc_ind}', None).to(src_flatten.device)
memory = self.encoder[i](src_flatten, pos_embed=pos_embed)
proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
# print([x.is_contiguous() for x in proj_feats ])
# broadcasting and fusion
inner_outs = [proj_feats[-1]]
for idx in range(len(self.in_channels) - 1, 0, -1):
feat_high = inner_outs[0]
feat_low = proj_feats[idx - 1]
feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high)
inner_outs[0] = feat_high
upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest')
inner_out = self.fpn_blocks[len(self.in_channels)-1-idx](torch.concat([upsample_feat, feat_low], dim=1))
inner_outs.insert(0, inner_out)
outs = [inner_outs[0]]
for idx in range(len(self.in_channels) - 1):
feat_low = outs[-1]
feat_height = inner_outs[idx + 1]
downsample_feat = self.downsample_convs[idx](feat_low)
out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_height], dim=1))
outs.append(out)
return outs
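# Shape walkthrough (illustrative): with the default in_channels=[512, 1024, 2048],
# feat_strides=[8, 16, 32] and a 640x640 input, the projected features are 256-channel maps
# of size 80x80, 40x40 and 20x20. Only the stride-32 level (use_encoder_idx=[2]) passes
# through the transformer encoder; the top-down FPN and bottom-up PAN paths then return three
# 256-channel maps at the same three resolutions.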
"""
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
Modules to compute the matching cost and solve the corresponding LSAP.
by lyuwenyu
"""
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch import nn
from .box_ops import box_cxcywh_to_xyxy, generalized_box_iou
from src.core import register
@register
class HungarianMatcher(nn.Module):
"""This class computes an assignment between the targets and the predictions of the network
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
while the others are un-matched (and thus treated as non-objects).
"""
__share__ = ['use_focal_loss', ]
def __init__(self, weight_dict, use_focal_loss=False, alpha=0.25, gamma=2.0):
"""Creates the matcher
Params:
cost_class: This is the relative weight of the classification error in the matching cost
cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
"""
super().__init__()
self.cost_class = weight_dict['cost_class']
self.cost_bbox = weight_dict['cost_bbox']
self.cost_giou = weight_dict['cost_giou']
self.use_focal_loss = use_focal_loss
self.alpha = alpha
self.gamma = gamma
assert self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0, "all costs can't be 0"
@torch.no_grad()
def forward(self, outputs, targets):
""" Performs the matching
Params:
outputs: This is a dict that contains at least these entries:
"pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
"pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
objects in the target) containing the class labels
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
Returns:
A list of size batch_size, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected predictions (in order)
- index_j is the indices of the corresponding selected targets (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
bs, num_queries = outputs["pred_logits"].shape[:2]
# We flatten to compute the cost matrices in a batch
if self.use_focal_loss:
out_prob = F.sigmoid(outputs["pred_logits"].flatten(0, 1))
else:
out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, so it can be omitted.
if self.use_focal_loss:
out_prob = out_prob[:, tgt_ids]
neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = self.alpha * ((1 - out_prob)**self.gamma) * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class - neg_cost_class
else:
cost_class = -out_prob[:, tgt_ids]
# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
# Compute the GIoU cost between boxes
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
# Final cost matrix
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = C.view(bs, num_queries, -1).cpu()
sizes = [len(v["boxes"]) for v in targets]
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
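# Shape sketch (illustrative): with batch_size=2, num_queries=300 and images containing 5 and
# 3 ground-truth boxes, C has shape [2, 300, 8]; C.split([5, 3], -1) yields per-image cost
# matrices of shape [300, 5] and [300, 3], each solved independently by linear_sum_assignment.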
"""by lyuwenyu
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
from src.core import register
__all__ = ['RTDETR', ]
@register
class RTDETR(nn.Module):
__inject__ = ['backbone', 'encoder', 'decoder', ]
def __init__(self, backbone: nn.Module, encoder, decoder, multi_scale=None):
super().__init__()
self.backbone = backbone
self.decoder = decoder
self.encoder = encoder
self.multi_scale = multi_scale
def forward(self, x, targets=None):
if self.multi_scale and self.training:
# sz = np.random.choice(self.multi_scale)
sz = int(np.random.choice(self.multi_scale))
x = F.interpolate(x, size=[sz, sz])
x = self.backbone(x)
x = self.encoder(x)
x = self.decoder(x, targets)
return x
def deploy(self, ):
self.eval()
for m in self.modules():
if hasattr(m, 'convert_to_deploy'):
m.convert_to_deploy()
return self
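# Usage note (illustrative; the concrete sizes are placeholders): multi_scale, when set, is a
# list of square training resolutions such as [480, 512, ..., 800]. At each training step one
# size is sampled and the batch is resized before the backbone, so the model sees varied input
# resolutions during training while evaluation keeps the original size.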
"""
reference:
https://github.com/facebookresearch/detr/blob/main/models/detr.py
by lyuwenyu
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
# from torchvision.ops import box_convert, generalized_box_iou
from .box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
from src.misc.dist import get_world_size, is_dist_available_and_initialized
from src.core import register
@register
class SetCriterion(nn.Module):
""" This class computes the loss for DETR.
The process happens in two steps:
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
"""
__share__ = ['num_classes', ]
__inject__ = ['matcher', ]
def __init__(self, matcher, weight_dict, losses, alpha=0.2, gamma=2.0, eos_coef=1e-4, num_classes=80):
""" Create the criterion.
Parameters:
num_classes: number of object categories, omitting the special no-object category
matcher: module able to compute a matching between targets and proposals
weight_dict: dict containing as key the names of the losses and as values their relative weight.
eos_coef: relative classification weight applied to the no-object category
losses: list of all the losses to be applied. See get_loss for list of available losses.
"""
super().__init__()
self.num_classes = num_classes
self.matcher = matcher
self.weight_dict = weight_dict
self.losses = losses
empty_weight = torch.ones(self.num_classes + 1)
empty_weight[-1] = eos_coef
self.register_buffer('empty_weight', empty_weight)
self.alpha = alpha
self.gamma = gamma
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
assert 'pred_logits' in outputs
src_logits = outputs['pred_logits']
idx = self._get_src_permutation_idx(indices)
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(src_logits.shape[:2], self.num_classes,
dtype=torch.int64, device=src_logits.device)
target_classes[idx] = target_classes_o
loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
losses = {'loss_ce': loss_ce}
if log:
# TODO this should probably be a separate loss, not hacked in this one here
losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
return losses
def loss_labels_bce(self, outputs, targets, indices, num_boxes, log=True):
src_logits = outputs['pred_logits']
idx = self._get_src_permutation_idx(indices)
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(src_logits.shape[:2], self.num_classes,
dtype=torch.int64, device=src_logits.device)
target_classes[idx] = target_classes_o
target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
loss = F.binary_cross_entropy_with_logits(src_logits, target * 1., reduction='none')
loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
return {'loss_bce': loss}
def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True):
assert 'pred_logits' in outputs
src_logits = outputs['pred_logits']
idx = self._get_src_permutation_idx(indices)
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(src_logits.shape[:2], self.num_classes,
dtype=torch.int64, device=src_logits.device)
target_classes[idx] = target_classes_o
target = F.one_hot(target_classes, num_classes=self.num_classes+1)[..., :-1]
# ce_loss = F.binary_cross_entropy_with_logits(src_logits, target * 1., reduction="none")
# prob = F.sigmoid(src_logits) # TODO .detach()
# p_t = prob * target + (1 - prob) * (1 - target)
# alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
# loss = alpha_t * ce_loss * ((1 - p_t) ** self.gamma)
# loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
loss = torchvision.ops.sigmoid_focal_loss(src_logits, target, self.alpha, self.gamma, reduction='none')
loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
return {'loss_focal': loss}
def loss_labels_vfl(self, outputs, targets, indices, num_boxes, log=True):
assert 'pred_boxes' in outputs
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs['pred_boxes'][idx]
target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes))
ious = torch.diag(ious).detach()
src_logits = outputs['pred_logits']
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(src_logits.shape[:2], self.num_classes,
dtype=torch.int64, device=src_logits.device)
target_classes[idx] = target_classes_o
target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
target_score_o[idx] = ious.to(target_score_o.dtype)
target_score = target_score_o.unsqueeze(-1) * target
pred_score = F.sigmoid(src_logits).detach()
weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score
loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction='none')
loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
return {'loss_vfl': loss}
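# Varifocal-style weighting used above (for reference): with q = IoU(pred_box, gt_box) as the
# soft classification target and y the one-hot label,
#   weight = alpha * sigmoid(logit)**gamma * (1 - y) + q * y
# so negatives are down-weighted focally while positives are weighted by their IoU quality.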
@torch.no_grad()
def loss_cardinality(self, outputs, targets, indices, num_boxes):
""" Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
"""
pred_logits = outputs['pred_logits']
device = pred_logits.device
tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
# Count the number of predictions that are NOT "no-object" (which is the last class)
card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
losses = {'cardinality_error': card_err}
return losses
def loss_boxes(self, outputs, targets, indices, num_boxes):
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
"""
assert 'pred_boxes' in outputs
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs['pred_boxes'][idx]
target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
losses = {}
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
losses['loss_bbox'] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(generalized_box_iou(
box_cxcywh_to_xyxy(src_boxes),
box_cxcywh_to_xyxy(target_boxes)))
losses['loss_giou'] = loss_giou.sum() / num_boxes
return losses
def loss_masks(self, outputs, targets, indices, num_boxes):
"""Compute the losses related to the masks: the focal loss and the dice loss.
targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
"""
assert "pred_masks" in outputs
src_idx = self._get_src_permutation_idx(indices)
tgt_idx = self._get_tgt_permutation_idx(indices)
src_masks = outputs["pred_masks"]
src_masks = src_masks[src_idx]
masks = [t["masks"] for t in targets]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
target_masks = target_masks.to(src_masks)
target_masks = target_masks[tgt_idx]
# upsample predictions to the target size
src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
mode="bilinear", align_corners=False)
src_masks = src_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
target_masks = target_masks.view(src_masks.shape)
losses = {
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
"loss_dice": dice_loss(src_masks, target_masks, num_boxes),
}
return losses
def _get_src_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
def _get_tgt_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
loss_map = {
'labels': self.loss_labels,
'cardinality': self.loss_cardinality,
'boxes': self.loss_boxes,
'masks': self.loss_masks,
'bce': self.loss_labels_bce,
'focal': self.loss_labels_focal,
'vfl': self.loss_labels_vfl,
}
assert loss in loss_map, f'do you really want to compute {loss} loss?'
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
def forward(self, outputs, targets):
""" This performs the loss computation.
Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
"""
outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs_without_aux, targets)
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
if is_dist_available_and_initialized():
torch.distributed.all_reduce(num_boxes)
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
# Compute all the requested losses
losses = {}
for loss in self.losses:
l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
losses.update(l_dict)
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if 'aux_outputs' in outputs:
for i, aux_outputs in enumerate(outputs['aux_outputs']):
indices = self.matcher(aux_outputs, targets)
for loss in self.losses:
if loss == 'masks':
# Intermediate masks losses are too costly to compute, we ignore them.
continue
kwargs = {}
if loss == 'labels':
# Logging is enabled only for the last layer
kwargs = {'log': False}
l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()}
losses.update(l_dict)
# In case of cdn auxiliary losses. For rtdetr
if 'dn_aux_outputs' in outputs:
assert 'dn_meta' in outputs, 'dn_meta is required when dn_aux_outputs is present'
indices = self.get_cdn_matched_indices(outputs['dn_meta'], targets)
num_boxes = num_boxes * outputs['dn_meta']['dn_num_group']
for i, aux_outputs in enumerate(outputs['dn_aux_outputs']):
# indices = self.matcher(aux_outputs, targets)
for loss in self.losses:
if loss == 'masks':
# Intermediate masks losses are too costly to compute, we ignore them.
continue
kwargs = {}
if loss == 'labels':
# Logging is enabled only for the last layer
kwargs = {'log': False}
l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
l_dict = {k + f'_dn_{i}': v for k, v in l_dict.items()}
losses.update(l_dict)
return losses
@staticmethod
def get_cdn_matched_indices(dn_meta, targets):
'''get_cdn_matched_indices
'''
dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
num_gts = [len(t['labels']) for t in targets]
device = targets[0]['labels'].device
dn_match_indices = []
for i, num_gt in enumerate(num_gts):
if num_gt > 0:
gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
gt_idx = gt_idx.tile(dn_num_group)
assert len(dn_positive_idx[i]) == len(gt_idx)
dn_match_indices.append((dn_positive_idx[i], gt_idx))
else:
dn_match_indices.append((torch.zeros(0, dtype=torch.int64, device=device), \
torch.zeros(0, dtype=torch.int64, device=device)))
return dn_match_indices
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
"""by lyuwenyu
"""
import math
import copy
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from .denoising import get_contrastive_denoising_training_group
from .utils import deformable_attention_core_func, get_activation, inverse_sigmoid
from .utils import bias_init_with_prob
from src.core import register
__all__ = ['RTDETRTransformer']
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act='relu'):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
self.act = nn.Identity() if act is None else get_activation(act)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
class MSDeformableAttention(nn.Module):
def __init__(self, embed_dim=256, num_heads=8, num_levels=4, num_points=4,):
"""
Multi-Scale Deformable Attention Module
"""
super(MSDeformableAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_levels = num_levels
self.num_points = num_points
self.total_points = num_heads * num_levels * num_points
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2,)
self.attention_weights = nn.Linear(embed_dim, self.total_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
self.ms_deformable_attn_core = deformable_attention_core_func
self._reset_parameters()
def _reset_parameters(self):
# sampling_offsets
init.constant_(self.sampling_offsets.weight, 0)
thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
grid_init = grid_init.reshape(self.num_heads, 1, 1, 2).tile([1, self.num_levels, self.num_points, 1])
scaling = torch.arange(1, self.num_points + 1, dtype=torch.float32).reshape(1, 1, -1, 1)
grid_init *= scaling
self.sampling_offsets.bias.data[...] = grid_init.flatten()
# attention_weights
init.constant_(self.attention_weights.weight, 0)
init.constant_(self.attention_weights.bias, 0)
# proj
init.xavier_uniform_(self.value_proj.weight)
init.constant_(self.value_proj.bias, 0)
init.xavier_uniform_(self.output_proj.weight)
init.constant_(self.output_proj.bias, 0)
def forward(self,
query,
reference_points,
value,
value_spatial_shapes,
value_mask=None):
"""
Args:
query (Tensor): [bs, query_length, C]
reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area
value (Tensor): [bs, value_length, C]
value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
value_level_start_index (List): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
Returns:
output (Tensor): [bs, Length_{query}, C]
"""
bs, Len_q = query.shape[:2]
Len_v = value.shape[1]
value = self.value_proj(value)
if value_mask is not None:
value_mask = value_mask.to(value.dtype).unsqueeze(-1)
value *= value_mask
value = value.reshape(bs, Len_v, self.num_heads, self.head_dim)
sampling_offsets = self.sampling_offsets(query).reshape(
bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).reshape(
bs, Len_q, self.num_heads, self.num_levels * self.num_points)
attention_weights = F.softmax(attention_weights, dim=-1).reshape(
bs, Len_q, self.num_heads, self.num_levels, self.num_points)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.tensor(value_spatial_shapes)
offset_normalizer = offset_normalizer.flip([1]).reshape(
1, 1, 1, self.num_levels, 1, 2)
sampling_locations = reference_points.reshape(
bs, Len_q, 1, self.num_levels, 1, 2
) + sampling_offsets / offset_normalizer
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2] + sampling_offsets /
self.num_points * reference_points[:, :, None, :, None, 2:] * 0.5)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".
format(reference_points.shape[-1]))
output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
return output
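# A minimal shape check for MSDeformableAttention, sketched under the layouts given
# in the forward docstring (query [bs, Len_q, C], value [bs, Len_v, C] with
# Len_v = sum(H_l * W_l), and 2-d reference points normalized to [0, 1]).
def _msda_shape_check():
    bs, len_q, dim = 2, 300, 256
    spatial_shapes = [(80, 80), (40, 40), (20, 20), (10, 10)]
    len_v = sum(h * w for h, w in spatial_shapes)
    attn = MSDeformableAttention(embed_dim=dim, num_heads=8, num_levels=4, num_points=4)
    query = torch.rand(bs, len_q, dim)
    value = torch.rand(bs, len_v, dim)
    reference_points = torch.rand(bs, len_q, 4, 2)  # one normalized point per level
    out = attn(query, reference_points, value, spatial_shapes)
    assert out.shape == (bs, len_q, dim)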
class TransformerDecoderLayer(nn.Module):
def __init__(self,
d_model=256,
n_head=8,
dim_feedforward=1024,
dropout=0.,
activation="relu",
n_levels=4,
n_points=4,):
super(TransformerDecoderLayer, self).__init__()
# self attention
self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# cross attention
self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, n_points)
self.dropout2 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.activation = getattr(F, activation)
self.dropout3 = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout4 = nn.Dropout(dropout)
self.norm3 = nn.LayerNorm(d_model)
# self._reset_parameters()
# def _reset_parameters(self):
# linear_init_(self.linear1)
# linear_init_(self.linear2)
# xavier_uniform_(self.linear1.weight)
# xavier_uniform_(self.linear2.weight)
def with_pos_embed(self, tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
return self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
def forward(self,
tgt,
reference_points,
memory,
memory_spatial_shapes,
memory_level_start_index,
attn_mask=None,
memory_mask=None,
query_pos_embed=None):
# self attention
q = k = self.with_pos_embed(tgt, query_pos_embed)
# if attn_mask is not None:
# attn_mask = torch.where(
# attn_mask.to(torch.bool),
# torch.zeros_like(attn_mask),
# torch.full_like(attn_mask, float('-inf'), dtype=tgt.dtype))
tgt2, _ = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# cross attention
tgt2 = self.cross_attn(\
self.with_pos_embed(tgt, query_pos_embed),
reference_points,
memory,
memory_spatial_shapes,
memory_mask)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
# ffn
tgt2 = self.forward_ffn(tgt)
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
class TransformerDecoder(nn.Module):
def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
super(TransformerDecoder, self).__init__()
self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
def forward(self,
tgt,
ref_points_unact,
memory,
memory_spatial_shapes,
memory_level_start_index,
bbox_head,
score_head,
query_pos_head,
attn_mask=None,
memory_mask=None):
output = tgt
dec_out_bboxes = []
dec_out_logits = []
ref_points_detach = F.sigmoid(ref_points_unact)
for i, layer in enumerate(self.layers):
ref_points_input = ref_points_detach.unsqueeze(2)
query_pos_embed = query_pos_head(ref_points_detach)
output = layer(output, ref_points_input, memory,
memory_spatial_shapes, memory_level_start_index,
attn_mask, memory_mask, query_pos_embed)
inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))
if self.training:
dec_out_logits.append(score_head[i](output))
if i == 0:
dec_out_bboxes.append(inter_ref_bbox)
else:
dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))
elif i == self.eval_idx:
dec_out_logits.append(score_head[i](output))
dec_out_bboxes.append(inter_ref_bbox)
break
ref_points = inter_ref_bbox
ref_points_detach = inter_ref_bbox.detach(
) if self.training else inter_ref_bbox
return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
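# Note on the refinement loop above: each decoder layer predicts a box delta in
# logit space, so the reference boxes are updated as
#   ref_i = sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_{i-1}))
# and, during training, the next layer sees ref_i.detach(), so gradients do not
# flow through the box coordinates across layers.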
@register
class RTDETRTransformer(nn.Module):
__share__ = ['num_classes']
def __init__(self,
num_classes=80,
hidden_dim=256,
num_queries=300,
position_embed_type='sine',
feat_channels=[512, 1024, 2048],
feat_strides=[8, 16, 32],
num_levels=3,
num_decoder_points=4,
nhead=8,
num_decoder_layers=6,
dim_feedforward=1024,
dropout=0.,
activation="relu",
num_denoising=100,
label_noise_ratio=0.5,
box_noise_scale=1.0,
learnt_init_query=False,
eval_spatial_size=None,
eval_idx=-1,
eps=1e-2,
aux_loss=True):
super(RTDETRTransformer, self).__init__()
assert position_embed_type in ['sine', 'learned'], \
f'ValueError: position_embed_type not supported {position_embed_type}!'
assert len(feat_channels) <= num_levels
assert len(feat_strides) == len(feat_channels)
for _ in range(num_levels - len(feat_strides)):
feat_strides.append(feat_strides[-1] * 2)
self.hidden_dim = hidden_dim
self.nhead = nhead
self.feat_strides = feat_strides
self.num_levels = num_levels
self.num_classes = num_classes
self.num_queries = num_queries
self.eps = eps
self.num_decoder_layers = num_decoder_layers
self.eval_spatial_size = eval_spatial_size
self.aux_loss = aux_loss
# backbone feature projection
self._build_input_proj_layer(feat_channels)
# Transformer module
decoder_layer = TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels, num_decoder_points)
self.decoder = TransformerDecoder(hidden_dim, decoder_layer, num_decoder_layers, eval_idx)
self.num_denoising = num_denoising
self.label_noise_ratio = label_noise_ratio
self.box_noise_scale = box_noise_scale
# denoising part
if num_denoising > 0:
# self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim, padding_idx=num_classes-1) # TODO for load paddle weights
self.denoising_class_embed = nn.Embedding(num_classes+1, hidden_dim, padding_idx=num_classes)
# decoder embedding
self.learnt_init_query = learnt_init_query
if learnt_init_query:
self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
# encoder head
self.enc_output = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim,)
)
self.enc_score_head = nn.Linear(hidden_dim, num_classes)
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
# decoder head
self.dec_score_head = nn.ModuleList([
nn.Linear(hidden_dim, num_classes)
for _ in range(num_decoder_layers)
])
self.dec_bbox_head = nn.ModuleList([
MLP(hidden_dim, hidden_dim, 4, num_layers=3)
for _ in range(num_decoder_layers)
])
# init encoder output anchors and valid_mask
if self.eval_spatial_size:
self.anchors, self.valid_mask = self._generate_anchors()
self._reset_parameters()
def _reset_parameters(self):
bias = bias_init_with_prob(0.01)
init.constant_(self.enc_score_head.bias, bias)
init.constant_(self.enc_bbox_head.layers[-1].weight, 0)
init.constant_(self.enc_bbox_head.layers[-1].bias, 0)
for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
init.constant_(cls_.bias, bias)
init.constant_(reg_.layers[-1].weight, 0)
init.constant_(reg_.layers[-1].bias, 0)
# linear_init_(self.enc_output[0])
init.xavier_uniform_(self.enc_output[0].weight)
if self.learnt_init_query:
init.xavier_uniform_(self.tgt_embed.weight)
init.xavier_uniform_(self.query_pos_head.layers[0].weight)
init.xavier_uniform_(self.query_pos_head.layers[1].weight)
def _build_input_proj_layer(self, feat_channels):
self.input_proj = nn.ModuleList()
for in_channels in feat_channels:
self.input_proj.append(
nn.Sequential(OrderedDict([
('conv', nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)),
('norm', nn.BatchNorm2d(self.hidden_dim,))])
)
)
in_channels = feat_channels[-1]
for _ in range(self.num_levels - len(feat_channels)):
self.input_proj.append(
nn.Sequential(OrderedDict([
('conv', nn.Conv2d(in_channels, self.hidden_dim, 3, 2, padding=1, bias=False)),
('norm', nn.BatchNorm2d(self.hidden_dim))])
)
)
in_channels = self.hidden_dim
def _get_encoder_input(self, feats):
# get projection features
proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
if self.num_levels > len(proj_feats):
len_srcs = len(proj_feats)
for i in range(len_srcs, self.num_levels):
if i == len_srcs:
proj_feats.append(self.input_proj[i](feats[-1]))
else:
proj_feats.append(self.input_proj[i](proj_feats[-1]))
# get encoder inputs
feat_flatten = []
spatial_shapes = []
level_start_index = [0, ]
for i, feat in enumerate(proj_feats):
_, _, h, w = feat.shape
# [b, c, h, w] -> [b, h*w, c]
feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
# [num_levels, 2]
spatial_shapes.append([h, w])
# [l], start index of each level
level_start_index.append(h * w + level_start_index[-1])
# [b, l, c]
feat_flatten = torch.concat(feat_flatten, 1)
level_start_index.pop()
return (feat_flatten, spatial_shapes, level_start_index)
def _generate_anchors(self,
spatial_shapes=None,
grid_size=0.05,
dtype=torch.float32,
device='cpu'):
if spatial_shapes is None:
spatial_shapes = [[int(self.eval_spatial_size[0] / s), int(self.eval_spatial_size[1] / s)]
for s in self.feat_strides
]
anchors = []
for lvl, (h, w) in enumerate(spatial_shapes):
grid_y, grid_x = torch.meshgrid(\
torch.arange(end=h, dtype=dtype), \
torch.arange(end=w, dtype=dtype), indexing='ij')
grid_xy = torch.stack([grid_x, grid_y], -1)
valid_WH = torch.tensor([w, h]).to(dtype)
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, h * w, 4))
anchors = torch.concat(anchors, 1).to(device)
valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
anchors = torch.log(anchors / (1 - anchors))
# anchors = torch.where(valid_mask, anchors, float('inf'))
# anchors[valid_mask] = torch.inf # valid_mask [1, 8400, 1]
anchors = torch.where(valid_mask, anchors, torch.inf)
return anchors, valid_mask
def _get_decoder_input(self,
memory,
spatial_shapes,
denoising_class=None,
denoising_bbox_unact=None):
bs, _, _ = memory.shape
# prepare input for decoder
if self.training or self.eval_spatial_size is None:
anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device)
else:
anchors, valid_mask = self.anchors.to(memory.device), self.valid_mask.to(memory.device)
# memory = torch.where(valid_mask, memory, 0)
memory = valid_mask.to(memory.dtype) * memory # TODO fix type error for onnx export
output_memory = self.enc_output(memory)
enc_outputs_class = self.enc_score_head(output_memory)
enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
reference_points_unact = enc_outputs_coord_unact.gather(dim=1, \
index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_unact.shape[-1]))
enc_topk_bboxes = F.sigmoid(reference_points_unact)
if denoising_bbox_unact is not None:
reference_points_unact = torch.concat(
[denoising_bbox_unact, reference_points_unact], 1)
enc_topk_logits = enc_outputs_class.gather(dim=1, \
index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]))
# extract region features
if self.learnt_init_query:
target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1])
else:
target = output_memory.gather(dim=1, \
index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
target = target.detach()
if denoising_class is not None:
target = torch.concat([denoising_class, target], 1)
return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
def forward(self, feats, targets=None):
# input projection and embedding
(memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)
# prepare denoising training
if self.training and self.num_denoising > 0:
denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
get_contrastive_denoising_training_group(targets, \
self.num_classes,
self.num_queries,
self.denoising_class_embed,
num_denoising=self.num_denoising,
label_noise_ratio=self.label_noise_ratio,
box_noise_scale=self.box_noise_scale, )
else:
denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
self._get_decoder_input(memory, spatial_shapes, denoising_class, denoising_bbox_unact)
# decoder
out_bboxes, out_logits = self.decoder(
target,
init_ref_points_unact,
memory,
spatial_shapes,
level_start_index,
self.dec_bbox_head,
self.dec_score_head,
self.query_pos_head,
attn_mask=attn_mask)
if self.training and dn_meta is not None:
dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2)
dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2)
out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
if self.training and self.aux_loss:
out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1])
out['aux_outputs'].extend(self._set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))
if self.training and dn_meta is not None:
out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
out['dn_meta'] = dn_meta
return out
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{'pred_logits': a, 'pred_boxes': b}
for a, b in zip(outputs_class, outputs_coord)]
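# A minimal smoke-test sketch for the decoder head, assuming a 640x640 input and the
# default feat_channels/feat_strides above; denoising is skipped outside training.
def _rtdetr_transformer_smoke_test():
    model = RTDETRTransformer(num_classes=80,
                              feat_channels=[512, 1024, 2048],
                              feat_strides=[8, 16, 32],
                              num_levels=3).eval()
    feats = [torch.rand(2, 512, 80, 80),   # stride 8
             torch.rand(2, 1024, 40, 40),  # stride 16
             torch.rand(2, 2048, 20, 20)]  # stride 32
    with torch.no_grad():
        out = model(feats)
    assert out['pred_logits'].shape == (2, 300, 80)
    assert out['pred_boxes'].shape == (2, 300, 4)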
"""by lyuwenyu
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from src.core import register
__all__ = ['RTDETRPostProcessor']
@register
class RTDETRPostProcessor(nn.Module):
__share__ = ['num_classes', 'use_focal_loss', 'num_top_queries', 'remap_mscoco_category']
def __init__(self, num_classes=80, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False) -> None:
super().__init__()
self.use_focal_loss = use_focal_loss
self.num_top_queries = num_top_queries
self.num_classes = num_classes
self.remap_mscoco_category = remap_mscoco_category
self.deploy_mode = False
def extra_repr(self) -> str:
return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}'
# def forward(self, outputs, orig_target_sizes):
def forward(self, outputs, orig_target_sizes):
logits, boxes = outputs['pred_logits'], outputs['pred_boxes']
# orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)
if self.use_focal_loss:
scores = F.sigmoid(logits)
            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
labels = index % self.num_classes
index = index // self.num_classes
boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]))
else:
            scores = F.softmax(logits, dim=-1)[:, :, :-1]
scores, labels = scores.max(dim=-1)
if scores.shape[1] > self.num_top_queries:
scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
labels = torch.gather(labels, dim=1, index=index)
boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
# TODO for onnx export
if self.deploy_mode:
return labels, boxes, scores
# TODO
if self.remap_mscoco_category:
from ...data.coco import mscoco_label2category
labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\
.to(boxes.device).reshape(labels.shape)
results = []
for lab, box, sco in zip(labels, boxes, scores):
result = dict(labels=lab, boxes=box, scores=sco)
results.append(result)
return results
def deploy(self, ):
self.eval()
self.deploy_mode = True
return self
@property
def iou_types(self, ):
return ('bbox', )
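# A minimal usage sketch for the postprocessor, assuming 640x640 original image
# sizes and random predictions shaped like the transformer outputs above.
def _postprocessor_smoke_test():
    postprocessor = RTDETRPostProcessor(num_classes=80, use_focal_loss=True, num_top_queries=300)
    outputs = {'pred_logits': torch.rand(2, 300, 80),
               'pred_boxes': torch.rand(2, 300, 4)}   # cxcywh, normalized
    orig_sizes = torch.tensor([[640, 640], [640, 640]], dtype=torch.float32)
    results = postprocessor(outputs, orig_sizes)
    assert len(results) == 2
    assert set(results[0].keys()) == {'labels', 'boxes', 'scores'}
    assert results[0]['boxes'].shape == (300, 4)  # xyxy in pixels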
"""by lyuwenyu
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def inverse_sigmoid(x: torch.Tensor, eps: float=1e-5) -> torch.Tensor:
x = x.clip(min=0., max=1.)
return torch.log(x.clip(min=eps) / (1 - x).clip(min=eps))
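# Sanity check: inverse_sigmoid is the clipped logit, so it round-trips with
# torch.sigmoid away from the clipped endpoints near 0 and 1.
def _inverse_sigmoid_roundtrip():
    x = torch.tensor([0.1, 0.5, 0.9])
    assert torch.allclose(torch.sigmoid(inverse_sigmoid(x)), x, atol=1e-4)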
def deformable_attention_core_func(value, value_spatial_shapes, sampling_locations, attention_weights):
"""
Args:
value (Tensor): [bs, value_length, n_head, c]
value_spatial_shapes (Tensor|List): [n_levels, 2]
sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]
Returns:
output (Tensor): [bs, Length_{query}, C]
"""
bs, _, n_head, c = value.shape
_, Len_q, _, n_levels, n_points, _ = sampling_locations.shape
split_shape = [h * w for h, w in value_spatial_shapes]
value_list = value.split(split_shape, dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level, (h, w) in enumerate(value_spatial_shapes):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_ = value_list[level].flatten(2).permute(
0, 2, 1).reshape(bs * n_head, c, h, w)
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
sampling_grid_l_ = sampling_grids[:, :, :, level].permute(
0, 2, 1, 3, 4).flatten(0, 1)
# N_*M_, D_, Lq_, P_
sampling_value_l_ = F.grid_sample(
value_l_,
sampling_grid_l_,
mode='bilinear',
padding_mode='zeros',
align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
attention_weights = attention_weights.permute(0, 2, 1, 3, 4).reshape(
bs * n_head, 1, Len_q, n_levels * n_points)
output = (torch.stack(
sampling_value_list, dim=-2).flatten(-2) *
attention_weights).sum(-1).reshape(bs, n_head * c, Len_q)
return output.permute(0, 2, 1)
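# A minimal shape check for the core function above, sketched with two feature
# levels and the tensor layouts documented in its docstring.
def _deformable_core_shape_check():
    bs, n_head, c = 2, 8, 32
    spatial_shapes = [(16, 16), (8, 8)]
    len_v = sum(h * w for h, w in spatial_shapes)
    len_q, n_levels, n_points = 10, 2, 4
    value = torch.rand(bs, len_v, n_head, c)
    sampling_locations = torch.rand(bs, len_q, n_head, n_levels, n_points, 2)
    attention_weights = torch.softmax(torch.rand(bs, len_q, n_head, n_levels, n_points), dim=-1)
    out = deformable_attention_core_func(value, spatial_shapes, sampling_locations, attention_weights)
    assert out.shape == (bs, len_q, n_head * c)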
def bias_init_with_prob(prior_prob=0.01):
"""initialize conv/fc bias value according to a given probability value."""
bias_init = float(-math.log((1 - prior_prob) / prior_prob))
return bias_init
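# Example: with the default prior of 0.01 the returned bias is about -4.595, so a
# zero-weight classification head starts with sigmoid(bias) ~= 0.01, the usual
# focal-loss initialization.
def _bias_prior_check():
    b = bias_init_with_prob(0.01)
    assert abs(torch.sigmoid(torch.tensor(b)).item() - 0.01) < 1e-6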
def get_activation(act, inplace: bool=True):
    '''get an activation module by name; also accepts None or an existing nn.Module
    '''
    if act is None:
        m = nn.Identity()
    elif isinstance(act, nn.Module):
        m = act
    else:
        act = act.lower()
        if act == 'silu':
            m = nn.SiLU()
        elif act == 'relu':
            m = nn.ReLU()
        elif act == 'leaky_relu':
            m = nn.LeakyReLU()
        elif act == 'gelu':
            m = nn.GELU()
        else:
            raise RuntimeError(f'Unsupported activation: {act}')
    if hasattr(m, 'inplace'):
        m.inplace = inplace
    return m
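# Usage examples for get_activation, assuming the signature above:
#   get_activation('silu')                   # nn.SiLU with inplace=True
#   get_activation('relu', inplace=False)    # nn.ReLU()
#   get_activation(None)                     # nn.Identity()
#   get_activation(nn.GELU())                # passes an existing module through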
Train/test script examples
- `CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master-port=8989 tools/train.py -c path/to/config &> train.log 2>&1 &`
- `-r path/to/checkpoint` to resume training from a checkpoint
- `--amp` to enable automatic mixed-precision training
- `--test-only` to run evaluation only
Tuning script examples
- `torchrun --master_port=8844 --nproc_per_node=4 tools/train.py -c configs/rtdetr/rtdetr_r18vd_6x_coco.yml -t https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_5x_coco_objects365_from_paddle.pth`
Export script examples
- `python tools/export_onnx.py -c path/to/config -r path/to/checkpoint --check`
GPU memory not released after training (kill orphaned training processes)
- `ps aux | grep "tools/train.py" | awk '{print $2}' | xargs kill -9`
Save all logs
- Append `&> train.log 2>&1 &` (run in background) or `&> train.log 2>&1` to the training command
"""by lyuwenyu
"""
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
import argparse
import numpy as np
from src.core import YAMLConfig
import torch
import torch.nn as nn
def main(args, ):
"""main
"""
cfg = YAMLConfig(args.config, resume=args.resume)
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
if 'ema' in checkpoint:
state = checkpoint['ema']['module']
else:
state = checkpoint['model']
else:
        raise AttributeError('Only resuming from a checkpoint (-r/--resume) is supported for loading the model state_dict for now.')
# NOTE load train mode state -> convert to deploy mode
cfg.model.load_state_dict(state)
class Model(nn.Module):
def __init__(self, ) -> None:
super().__init__()
self.model = cfg.model.deploy()
self.postprocessor = cfg.postprocessor.deploy()
print(self.postprocessor.deploy_mode)
def forward(self, images, orig_target_sizes):
outputs = self.model(images)
return self.postprocessor(outputs, orig_target_sizes)
model = Model()
dynamic_axes = {
'images': {0: 'N', },
'orig_target_sizes': {0: 'N'}
}
data = torch.rand(1, 3, 640, 640)
size = torch.tensor([[640, 640]])
torch.onnx.export(
model,
(data, size),
args.file_name,
input_names=['images', 'orig_target_sizes'],
output_names=['labels', 'boxes', 'scores'],
dynamic_axes=dynamic_axes,
opset_version=16,
verbose=False
)
if args.check:
import onnx
onnx_model = onnx.load(args.file_name)
onnx.checker.check_model(onnx_model)
print('Check export onnx model done...')
if args.simplify:
import onnxsim
dynamic = True
input_shapes = {'images': data.shape, 'orig_target_sizes': size.shape} if dynamic else None
onnx_model_simplify, check = onnxsim.simplify(args.file_name, input_shapes=input_shapes, dynamic_input_shape=dynamic)
onnx.save(onnx_model_simplify, args.file_name)
print(f'Simplify onnx model {check}...')
# import onnxruntime as ort
# from PIL import Image, ImageDraw
# from torchvision.transforms import ToTensor
# # print(onnx.helper.printable_graph(mm.graph))
# im = Image.open('./000000014439.jpg').convert('RGB')
# im = im.resize((640, 640))
# im_data = ToTensor()(im)[None]
# print(im_data.shape)
# sess = ort.InferenceSession(args.file_name)
# output = sess.run(
# # output_names=['labels', 'boxes', 'scores'],
# output_names=None,
# input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}
# )
# # print(type(output))
# # print([out.shape for out in output])
# labels, boxes, scores = output
# draw = ImageDraw.Draw(im)
# thrh = 0.6
# for i in range(im_data.shape[0]):
# scr = scores[i]
# lab = labels[i][scr > thrh]
# box = boxes[i][scr > thrh]
# print(i, sum(scr > thrh))
# for b in box:
# draw.rectangle(list(b), outline='red',)
# draw.text((b[0], b[1]), text=str(lab[i]), fill='blue', )
# im.save('test.jpg')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config', '-c', type=str, )
parser.add_argument('--resume', '-r', type=str, )
parser.add_argument('--file-name', '-f', type=str, default='model.onnx')
parser.add_argument('--check', action='store_true', default=False,)
parser.add_argument('--simplify', action='store_true', default=False,)
args = parser.parse_args()
main(args)