Commit e9cee049 authored by luopl

Initial commit
# Copyright (c) Tencent Inc. All rights reserved.
import os.path as osp
from typing import List, Union
from mmengine.fileio import get_local_path, join_path
from mmengine.utils import is_abs
from mmdet.datasets.coco import CocoDataset
from mmyolo.registry import DATASETS
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
@DATASETS.register_module()
class YOLOv5MixedGroundingDataset(BatchShapePolicyDataset, CocoDataset):
"""Mixed grounding dataset."""
METAINFO = {
'classes': ('object',),
'palette': [(220, 20, 60)]}
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.coco = self.COCOAPI(local_path)
img_ids = self.coco.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.coco.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.coco.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.coco
# print(len(data_list))
return data_list
def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
"""Parse raw annotation to target format.
Args:
raw_data_info (dict): Raw data information load from ``ann_file``
Returns:
Union[dict, List[dict]]: Parsed annotation.
"""
img_info = raw_data_info['raw_img_info']
ann_info = raw_data_info['raw_ann_info']
data_info = {}
img_path = None
img_prefix = self.data_prefix.get('img', None)
if isinstance(img_prefix, str):
img_path = osp.join(img_prefix, img_info['file_name'])
elif isinstance(img_prefix, (list, tuple)):
for prefix in img_prefix:
candidate_img_path = osp.join(prefix, img_info['file_name'])
if osp.exists(candidate_img_path):
img_path = candidate_img_path
break
        assert img_path is not None, (
            f'Image path {img_info["file_name"]} not found in '
            f'{img_prefix}')
if self.data_prefix.get('seg', None):
seg_map_path = osp.join(
self.data_prefix['seg'],
img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
else:
seg_map_path = None
data_info['img_path'] = img_path
data_info['img_id'] = img_info['img_id']
data_info['seg_map_path'] = seg_map_path
data_info['height'] = float(img_info['height'])
data_info['width'] = float(img_info['width'])
cat2id = {}
texts = []
for ann in ann_info:
cat_name = ' '.join([img_info['caption'][t[0]:t[1]]
for t in ann['tokens_positive']])
if cat_name not in cat2id:
cat2id[cat_name] = len(cat2id)
texts.append([cat_name])
data_info['texts'] = texts
instances = []
for i, ann in enumerate(ann_info):
instance = {}
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0,
min(x1 + w, float(img_info['width'])) - max(x1, 0))
inter_h = max(0,
min(y1 + h, float(img_info['height'])) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if ann['area'] <= 0 or w < 1 or h < 1:
continue
bbox = [x1, y1, x1 + w, y1 + h]
if ann.get('iscrowd', False):
instance['ignore_flag'] = 1
else:
instance['ignore_flag'] = 0
instance['bbox'] = bbox
cat_name = ' '.join([img_info['caption'][t[0]:t[1]]
for t in ann['tokens_positive']])
instance['bbox_label'] = cat2id[cat_name]
if ann.get('segmentation', None):
instance['mask'] = ann['segmentation']
instances.append(instance)
# NOTE: for detection task, we set `is_detection` to 1
data_info['is_detection'] = 1
data_info['instances'] = instances
# print(data_info['texts'])
return data_info
def filter_data(self) -> List[dict]:
"""Filter annotations according to filter_cfg.
Returns:
List[dict]: Filtered results.
"""
if self.test_mode:
return self.data_list
if self.filter_cfg is None:
return self.data_list
filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
min_size = self.filter_cfg.get('min_size', 0)
# obtain images that contain annotation
ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
valid_data_infos = []
for i, data_info in enumerate(self.data_list):
img_id = data_info['img_id']
width = int(data_info['width'])
height = int(data_info['height'])
if filter_empty_gt and img_id not in ids_with_ann:
continue
if min(width, height) >= min_size:
valid_data_infos.append(data_info)
return valid_data_infos
def _join_prefix(self):
"""Join ``self.data_root`` with ``self.data_prefix`` and
``self.ann_file``.
"""
# Automatically join annotation file path with `self.root` if
# `self.ann_file` is not an absolute path.
if self.ann_file and not is_abs(self.ann_file) and self.data_root:
self.ann_file = join_path(self.data_root, self.ann_file)
# Automatically join data directory with `self.root` if path value in
# `self.data_prefix` is not an absolute path.
for data_key, prefix in self.data_prefix.items():
if isinstance(prefix, (list, tuple)):
abs_prefix = []
for p in prefix:
if not is_abs(p) and self.data_root:
abs_prefix.append(join_path(self.data_root, p))
else:
abs_prefix.append(p)
self.data_prefix[data_key] = abs_prefix
elif isinstance(prefix, str):
if not is_abs(prefix) and self.data_root:
self.data_prefix[data_key] = join_path(
self.data_root, prefix)
else:
self.data_prefix[data_key] = prefix
else:
raise TypeError('prefix should be a string, tuple or list,'
f'but got {type(prefix)}')
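# Illustrative sketch (editorial, not part of the module): how `parse_data_info`
# builds the per-image `texts` from grounding annotations. Values below are
# hypothetical.
#
#   caption = 'a woman holding a red umbrella'
#   ann = {'tokens_positive': [[18, 30]]}   # character spans into the caption
#   cat_name = ' '.join(caption[s:e] for s, e in ann['tokens_positive'])
#   # -> 'red umbrella'; each unique phrase gets an index in `cat2id`, is
#   # appended to `data_info['texts']` as a one-element list, and that index
#   # becomes the instance's `bbox_label`.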
# Copyright (c) Tencent Inc. All rights reserved.
from mmdet.datasets import Objects365V1Dataset
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS
@DATASETS.register_module()
class YOLOv5Objects365V1Dataset(BatchShapePolicyDataset, Objects365V1Dataset):
"""Dataset for YOLOv5 VOC Dataset.
We only add `BatchShapePolicy` function compared with Objects365V1Dataset.
See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
"""
pass
# Copyright (c) Tencent Inc. All rights reserved.
from mmdet.datasets import Objects365V2Dataset
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS
@DATASETS.register_module()
class YOLOv5Objects365V2Dataset(BatchShapePolicyDataset, Objects365V2Dataset):
"""Dataset for YOLOv5 VOC Dataset.
We only add `BatchShapePolicy` function compared with Objects365V1Dataset.
See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
"""
pass
# Copyright (c) Tencent Inc. All rights reserved.
import copy
import json
import os.path as osp
from typing import List
from mmengine.fileio import get_local_path
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets import CocoDataset
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS
v3det_ignore_list = [
'a00013820/26_275_28143226914_ff3a247c53_c.jpg',
'n03815615/12_1489_32968099046_be38fa580e_c.jpg',
'n04550184/19_1480_2504784164_ffa3db8844_c.jpg',
'a00008703/2_363_3576131784_dfac6fc6ce_c.jpg',
'n02814533/28_2216_30224383848_a90697f1b3_c.jpg',
'n12026476/29_186_15091304754_5c219872f7_c.jpg',
'n01956764/12_2004_50133201066_72e0d9fea5_c.jpg',
'n03785016/14_2642_518053131_d07abcb5da_c.jpg',
'a00011156/33_250_4548479728_9ce5246596_c.jpg',
'a00009461/19_152_2792869324_db95bebc84_c.jpg',
]
# # ugly code here
# with open(osp.join("data/v3det/cats.json"), 'r') as f:
# _classes = json.load(f)['classes']
@DATASETS.register_module()
class V3DetDataset(CocoDataset):
"""Objects365 v1 dataset for detection."""
METAINFO = {'classes': 'classes', 'palette': None}
COCOAPI = COCO
# ann_id is unique in coco dataset.
ANN_ID_UNIQUE = True
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
with get_local_path(self.ann_file,
backend_args=self.backend_args) as local_path:
self.coco = self.COCOAPI(local_path)
        # The 'categories' list can be inconsistent between the train and val
        # annotation files; sort the list (or dict) before calling get_cat_ids.
cats = self.coco.cats
sorted_cats = {i: cats[i] for i in sorted(cats)}
self.coco.cats = sorted_cats
categories = self.coco.dataset['categories']
sorted_categories = sorted(categories, key=lambda i: i['id'])
self.coco.dataset['categories'] = sorted_categories
# The order of returned `cat_ids` will not
# change with the order of the `classes`
self.cat_ids = self.coco.get_cat_ids(
cat_names=self.metainfo['classes'])
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
img_ids = self.coco.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.coco.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.coco.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
file_name = osp.join(
osp.split(osp.split(raw_img_info['file_name'])[0])[-1],
osp.split(raw_img_info['file_name'])[-1])
if file_name in v3det_ignore_list:
continue
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.coco
return data_list
@DATASETS.register_module()
class YOLOv5V3DetDataset(BatchShapePolicyDataset, V3DetDataset):
"""Dataset for YOLOv5 VOC Dataset.
We only add `BatchShapePolicy` function compared with Objects365V1Dataset.
See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
"""
pass
# Copyright (c) Tencent Inc. All rights reserved.
from .optimizers import * # noqa
# Copyright (c) Tencent Inc. All rights reserved.
from .yolow_v5_optim_constructor import YOLOWv5OptimizerConstructor
__all__ = ['YOLOWv5OptimizerConstructor']
# Copyright (c) Tencent Inc. All rights reserved.
import logging
from typing import List, Optional, Union
import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
from mmengine.dist import get_world_size
from mmengine.logging import print_log
from mmengine.optim import OptimWrapper, DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
OPTIMIZERS)
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class YOLOWv5OptimizerConstructor(DefaultOptimWrapperConstructor):
"""YOLO World v5 constructor for optimizers."""
def __init__(self,
optim_wrapper_cfg: dict,
paramwise_cfg: Optional[dict] = None) -> None:
super().__init__(optim_wrapper_cfg, paramwise_cfg)
self.base_total_batch_size = self.paramwise_cfg.pop(
'base_total_batch_size', 64)
def add_params(self,
params: List[dict],
module: nn.Module,
prefix: str = '',
is_dcn_module: Optional[Union[int, float]] = None) -> None:
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (list[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
prefix (str): The prefix of the module
is_dcn_module (int|float|None): If the current module is a
submodule of DCN, `is_dcn_module` will be passed to
control conv_offset layer's learning rate. Defaults to None.
"""
# get param-wise options
custom_keys = self.paramwise_cfg.get('custom_keys', {})
# first sort with alphabet order and then sort with reversed len of str
sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None)
bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None)
flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', None)
# special rules for norm layers and depth-wise conv layers
is_norm = isinstance(module,
(_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
is_dwconv = (
isinstance(module, torch.nn.Conv2d)
and module.in_channels == module.groups)
for name, param in module.named_parameters(recurse=False):
param_group = {'params': [param]}
if bypass_duplicate and self._is_in(param_group, params):
print_log(
f'{prefix} is duplicate. It is skipped since '
f'bypass_duplicate={bypass_duplicate}',
logger='current',
level=logging.WARNING)
continue
if not param.requires_grad:
params.append(param_group)
continue
            # if the parameter matches one of the custom keys, ignore other rules
for key in sorted_keys:
if key in f'{prefix}.{name}':
lr_mult = custom_keys[key].get('lr_mult', 1.)
param_group['lr'] = self.base_lr * lr_mult
if self.base_wd is not None:
decay_mult = custom_keys[key].get('decay_mult', 1.)
param_group['weight_decay'] = self.base_wd * decay_mult
# add custom settings to param_group
for k, v in custom_keys[key].items():
param_group[k] = v
break
            # NOTE: the behaviour is different from MMDetection:
            # bias_lr_mult affects all bias parameters
            # except for norm.bias and dcn.conv_offset.bias
if name == 'bias' and not (
is_norm or is_dcn_module) and bias_lr_mult is not None:
param_group['lr'] = self.base_lr * bias_lr_mult
if (prefix.find('conv_offset') != -1 and is_dcn_module
and dcn_offset_lr_mult is not None
and isinstance(module, torch.nn.Conv2d)):
# deal with both dcn_offset's bias & weight
param_group['lr'] = self.base_lr * dcn_offset_lr_mult
# apply weight decay policies
if self.base_wd is not None:
# norm decay
if is_norm and norm_decay_mult is not None:
param_group[
'weight_decay'] = self.base_wd * norm_decay_mult
# bias lr and decay
elif (name == 'bias' and not is_dcn_module
and bias_decay_mult is not None):
param_group[
'weight_decay'] = self.base_wd * bias_decay_mult
# depth-wise conv
elif is_dwconv and dwconv_decay_mult is not None:
param_group[
'weight_decay'] = self.base_wd * dwconv_decay_mult
# flatten parameters except dcn offset
elif (param.ndim == 1 and not is_dcn_module
and flat_decay_mult is not None):
param_group[
'weight_decay'] = self.base_wd * flat_decay_mult
params.append(param_group)
for key, value in param_group.items():
if key == 'params':
continue
full_name = f'{prefix}.{name}' if prefix else name
print_log(
f'paramwise_options -- {full_name}:{key}={value}',
logger='current')
if mmcv_full_available():
from mmcv.ops import DeformConv2d, ModulatedDeformConv2d
is_dcn_module = isinstance(module,
(DeformConv2d, ModulatedDeformConv2d))
else:
is_dcn_module = False
for child_name, child_mod in module.named_children():
child_prefix = f'{prefix}.{child_name}' if prefix else child_name
self.add_params(
params,
child_mod,
prefix=child_prefix,
is_dcn_module=is_dcn_module)
def __call__(self, model: nn.Module) -> OptimWrapper:
if hasattr(model, 'module'):
model = model.module
optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
optimizer_cfg = self.optimizer_cfg.copy()
# follow the original yolov5 implementation
if 'batch_size_per_gpu' in optimizer_cfg:
batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
# No scaling if total_batch_size is less than
# base_total_batch_size, otherwise linear scaling.
total_batch_size = get_world_size() * batch_size_per_gpu
accumulate = max(
round(self.base_total_batch_size / total_batch_size), 1)
scale_factor = total_batch_size * \
accumulate / self.base_total_batch_size
if scale_factor != 1:
weight_decay = optimizer_cfg.get('weight_decay', 0)
weight_decay *= scale_factor
optimizer_cfg['weight_decay'] = weight_decay
print_log(f'Scaled weight_decay to {weight_decay}', 'current')
# if no paramwise option is specified, just use the global setting
if not self.paramwise_cfg:
optimizer_cfg['params'] = model.parameters()
optimizer = OPTIMIZERS.build(optimizer_cfg)
else:
# set param-wise lr and weight decay recursively
params: List = []
self.add_params(params, model)
optimizer_cfg['params'] = params
optimizer = OPTIMIZERS.build(optimizer_cfg)
optim_wrapper = OPTIM_WRAPPERS.build(
optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
return optim_wrapper
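# Worked example (assumed numbers, not taken from any config): with
# base_total_batch_size=64, 8 GPUs and batch_size_per_gpu=16,
#
#   total_batch_size = 8 * 16                    # 128
#   accumulate = max(round(64 / 128), 1)         # 1
#   scale_factor = 128 * 1 / 64                  # 2.0
#   weight_decay = 0.05 * scale_factor           # 0.1
#
# For total_batch_size <= 64 the accumulate factor compensates, scale_factor
# stays 1, and weight_decay is left unchanged.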
# Copyright (c) Tencent Inc. All rights reserved.
from .backbones import * # noqa
from .layers import * # noqa
from .detectors import * # noqa
from .losses import * # noqa
from .data_preprocessors import * # noqa
from .dense_heads import * # noqa
from .necks import * # noqa
from .assigner import * # noqa
from .task_aligned_assigner import YOLOWorldSegAssigner
__all__ = ['YOLOWorldSegAssigner']
# Copyright (c) Tencent Inc. All rights reserved.
import torch
from torch import Tensor
from mmyolo.registry import TASK_UTILS
from mmyolo.models.task_modules.assigners import BatchTaskAlignedAssigner
from mmyolo.models.task_modules.assigners.utils import select_highest_overlaps
@TASK_UTILS.register_module()
class YOLOWorldSegAssigner(BatchTaskAlignedAssigner):
def __init__(self,
num_classes: int,
topk: int = 13,
alpha: float = 1,
beta: float = 6,
eps: float = 1e-7,
use_ciou: bool = False):
super().__init__(num_classes, topk, alpha, beta, eps, use_ciou)
@torch.no_grad()
def forward(
self,
pred_bboxes: Tensor,
pred_scores: Tensor,
priors: Tensor,
gt_labels: Tensor,
gt_bboxes: Tensor,
pad_bbox_flag: Tensor,
) -> dict:
"""Assign gt to bboxes.
The assignment is done in following steps
1. compute alignment metric between all bbox (bbox of all pyramid
levels) and gt
2. select top-k bbox as candidates for each gt
3. limit the positive sample's center in gt (because the anchor-free
detector only can predict positive distance)
Args:
pred_bboxes (Tensor): Predict bboxes,
shape(batch_size, num_priors, 4)
pred_scores (Tensor): Scores of predict bboxes,
shape(batch_size, num_priors, num_classes)
priors (Tensor): Model priors, shape (num_priors, 4)
            gt_labels (Tensor): Ground truth labels,
                shape(batch_size, num_gt, 1)
            gt_bboxes (Tensor): Ground truth bboxes,
                shape(batch_size, num_gt, 4)
pad_bbox_flag (Tensor): Ground truth bbox mask,
1 means bbox, 0 means no bbox,
shape(batch_size, num_gt, 1)
Returns:
            assigned_result (dict): Assigned result, containing:
assigned_labels (Tensor): Assigned labels,
shape(batch_size, num_priors)
assigned_bboxes (Tensor): Assigned boxes,
shape(batch_size, num_priors, 4)
assigned_scores (Tensor): Assigned scores,
shape(batch_size, num_priors, num_classes)
fg_mask_pre_prior (Tensor): Force ground truth matching mask,
shape(batch_size, num_priors)
"""
# (num_priors, 4) -> (num_priors, 2)
priors = priors[:, :2]
batch_size = pred_scores.size(0)
num_gt = gt_bboxes.size(1)
assigned_result = {
'assigned_labels':
gt_bboxes.new_full(pred_scores[..., 0].shape, self.num_classes),
'assigned_bboxes':
gt_bboxes.new_full(pred_bboxes.shape, 0),
'assigned_scores':
gt_bboxes.new_full(pred_scores.shape, 0),
'fg_mask_pre_prior':
gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
}
if num_gt == 0:
return assigned_result
pos_mask, alignment_metrics, overlaps = self.get_pos_mask(
pred_bboxes, pred_scores, priors, gt_labels, gt_bboxes,
pad_bbox_flag, batch_size, num_gt)
(assigned_gt_idxs, fg_mask_pre_prior,
pos_mask) = select_highest_overlaps(pos_mask, overlaps, num_gt)
# assigned target
assigned_labels, assigned_bboxes, assigned_scores = self.get_targets(
gt_labels, gt_bboxes, assigned_gt_idxs, fg_mask_pre_prior,
batch_size, num_gt)
# normalize
alignment_metrics *= pos_mask
pos_align_metrics = alignment_metrics.max(axis=-1, keepdim=True)[0]
pos_overlaps = (overlaps * pos_mask).max(axis=-1, keepdim=True)[0]
norm_align_metric = (
alignment_metrics * pos_overlaps /
(pos_align_metrics + self.eps)).max(-2)[0].unsqueeze(-1)
assigned_scores = assigned_scores * norm_align_metric
assigned_result['assigned_labels'] = assigned_labels
assigned_result['assigned_bboxes'] = assigned_bboxes
assigned_result['assigned_scores'] = assigned_scores
assigned_result['fg_mask_pre_prior'] = fg_mask_pre_prior.bool()
assigned_result['assigned_gt_idxs'] = assigned_gt_idxs
return assigned_result
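# Shape sketch (hypothetical sizes): for a batch of 2 images, 8400 priors and
# up to 10 padded gt boxes per image, the assigner expects
#
#   pred_bboxes:   (2, 8400, 4)        pred_scores:   (2, 8400, num_classes)
#   priors:        (8400, 4)           gt_labels:     (2, 10, 1)
#   gt_bboxes:     (2, 10, 4)          pad_bbox_flag: (2, 10, 1)
#
# and returns labels/bboxes/scores assigned to each of the 8400 priors.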
# Copyright (c) Tencent Inc. All rights reserved.
# YOLO Multi-Modal Backbone (Vision Language)
# Vision: YOLOv8 CSPDarknet
# Language: CLIP Text Encoder (12-layer transformer)
from .mm_backbone import (
MultiModalYOLOBackbone,
HuggingVisionBackbone,
HuggingCLIPLanguageBackbone,
PseudoLanguageBackbone)
__all__ = [
'MultiModalYOLOBackbone',
'HuggingVisionBackbone',
'HuggingCLIPLanguageBackbone',
'PseudoLanguageBackbone'
]
# Copyright (c) Tencent Inc. All rights reserved.
import itertools
from typing import List, Sequence, Tuple
import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmengine.model import BaseModule
from mmyolo.registry import MODELS
from mmdet.utils import OptMultiConfig, ConfigType
from transformers import (AutoTokenizer, AutoModel, CLIPTextConfig)
from transformers import CLIPTextModelWithProjection as CLIPTP
@MODELS.register_module()
class HuggingVisionBackbone(BaseModule):
def __init__(self,
model_name: str,
out_indices: Sequence[int] = (0, 1, 2, 3),
norm_eval: bool = True,
frozen_modules: Sequence[str] = (),
init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.norm_eval = norm_eval
        self.frozen_modules = frozen_modules
        self.out_indices = out_indices
        self.model = AutoModel.from_pretrained(model_name)
        self._freeze_modules()
    def forward(self, image: Tensor) -> Tuple[Tensor]:
        encoded_dict = self.model(pixel_values=image,
                                  output_hidden_states=True)
        hidden_states = encoded_dict.hidden_states
        img_feats = encoded_dict.get('reshaped_hidden_states', hidden_states)
        img_feats = [img_feats[i] for i in self.out_indices]
        return tuple(img_feats)
def _freeze_modules(self):
for name, module in self.model.named_modules():
for frozen_name in self.frozen_modules:
if name.startswith(frozen_name):
module.eval()
for param in module.parameters():
param.requires_grad = False
break
def train(self, mode=True):
super().train(mode)
self._freeze_modules()
if mode and self.norm_eval:
for m in self.modules():
                # trick: eval only has an effect on BatchNorm layers
if isinstance(m, _BatchNorm):
m.eval()
@MODELS.register_module()
class HuggingCLIPLanguageBackbone(BaseModule):
def __init__(self,
model_name: str,
frozen_modules: Sequence[str] = (),
dropout: float = 0.0,
training_use_cache: bool = False,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg=init_cfg)
self.frozen_modules = frozen_modules
self.training_use_cache = training_use_cache
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
clip_config = CLIPTextConfig.from_pretrained(model_name,
attention_dropout=dropout)
self.model = CLIPTP.from_pretrained(model_name, config=clip_config)
self._freeze_modules()
def forward_tokenizer(self, texts):
if not hasattr(self, 'text'):
text = list(itertools.chain(*texts))
text = self.tokenizer(text=text, return_tensors='pt', padding=True)
self.text = text.to(device=self.model.device)
return self.text
def forward(self, text: List[List[str]]) -> Tensor:
num_per_batch = [len(t) for t in text]
assert max(num_per_batch) == min(num_per_batch), (
'number of sequences not equal in batch')
text = list(itertools.chain(*text))
text = self.tokenizer(text=text, return_tensors='pt', padding=True)
text = text.to(device=self.model.device)
txt_outputs = self.model(**text)
txt_feats = txt_outputs.text_embeds
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
txt_feats = txt_feats.reshape(-1, num_per_batch[0],
txt_feats.shape[-1])
return txt_feats
def _freeze_modules(self):
if len(self.frozen_modules) == 0:
            # nothing to freeze
return
if self.frozen_modules[0] == "all":
self.model.eval()
for _, module in self.model.named_modules():
module.eval()
for param in module.parameters():
param.requires_grad = False
return
for name, module in self.model.named_modules():
for frozen_name in self.frozen_modules:
if name.startswith(frozen_name):
module.eval()
for param in module.parameters():
param.requires_grad = False
break
def train(self, mode=True):
super().train(mode)
self._freeze_modules()
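# Usage sketch (illustrative; the model name is an assumption, not a released
# config): the text backbone takes one list of prompts per image, all lists of
# equal length, and returns L2-normalized embeddings of shape
# (batch, num_texts, embed_dim):
#
#   text_model = HuggingCLIPLanguageBackbone(
#       model_name='openai/clip-vit-base-patch32')
#   texts = [['person', 'dog', 'red umbrella'],
#            ['person', 'dog', 'red umbrella']]
#   txt_feats = text_model(texts)   # (2, 3, 512) for CLIP ViT-B/32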
@MODELS.register_module()
class PseudoLanguageBackbone(BaseModule):
"""Pseudo Language Backbone
Args:
text_embed_path (str): path to the text embedding file
"""
def __init__(self,
text_embed_path: str = "",
test_embed_path: str = None,
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg)
# {text:embed}
self.text_embed = torch.load(text_embed_path, map_location='cpu')
if test_embed_path is None:
self.test_embed = self.text_embed
else:
self.test_embed = torch.load(test_embed_path)
self.register_buffer("buff", torch.zeros([
1,
]))
def forward_cache(self, text: List[List[str]]) -> Tensor:
if not hasattr(self, "cache"):
self.cache = self.forward_text(text)
return self.cache
def forward(self, text: List[List[str]]) -> Tensor:
if self.training:
return self.forward_text(text)
else:
return self.forward_cache(text)
def forward_text(self, text: List[List[str]]) -> Tensor:
num_per_batch = [len(t) for t in text]
assert max(num_per_batch) == min(num_per_batch), (
'number of sequences not equal in batch')
text = list(itertools.chain(*text))
if self.training:
text_embed_dict = self.text_embed
else:
text_embed_dict = self.test_embed
text_embeds = torch.stack(
[text_embed_dict[x.split("/")[0]] for x in text])
# requires no grad and force to float
text_embeds = text_embeds.to(
self.buff.device).requires_grad_(False).float()
text_embeds = text_embeds.reshape(-1, num_per_batch[0],
text_embeds.shape[-1])
return text_embeds
@MODELS.register_module()
class MultiModalYOLOBackbone(BaseModule):
def __init__(self,
image_model: ConfigType,
text_model: ConfigType,
frozen_stages: int = -1,
with_text_model: bool = True,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg)
self.with_text_model = with_text_model
self.image_model = MODELS.build(image_model)
if self.with_text_model:
self.text_model = MODELS.build(text_model)
else:
self.text_model = None
self.frozen_stages = frozen_stages
self._freeze_stages()
def _freeze_stages(self):
"""Freeze the parameters of the specified stage so that they are no
longer updated."""
if self.frozen_stages >= 0:
for i in range(self.frozen_stages + 1):
m = getattr(self.image_model, self.image_model.layers[i])
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode: bool = True):
"""Convert the model into training mode while keep normalization layer
frozen."""
super().train(mode)
self._freeze_stages()
def forward(self, image: Tensor,
text: List[List[str]]) -> Tuple[Tuple[Tensor], Tensor]:
img_feats = self.image_model(image)
if self.with_text_model:
txt_feats = self.text_model(text)
return img_feats, txt_feats
else:
return img_feats, None
def forward_text(self, text: List[List[str]]) -> Tensor:
assert self.with_text_model, "forward_text() requires a text model"
txt_feats = self.text_model(text)
return txt_feats
def forward_image(self, image: Tensor) -> Tuple[Tensor]:
return self.image_model(image)
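# Config sketch (field values are editorial assumptions, not a released
# config): MultiModalYOLOBackbone is typically built from the registry,
# pairing a YOLOv8 CSPDarknet image model with the CLIP text model above:
#
#   backbone = dict(
#       type='MultiModalYOLOBackbone',
#       image_model=dict(type='YOLOv8CSPDarknet',
#                        arch='P5',
#                        last_stage_out_channels=1024,
#                        deepen_factor=0.33,
#                        widen_factor=0.5),
#       text_model=dict(type='HuggingCLIPLanguageBackbone',
#                       model_name='openai/clip-vit-base-patch32',
#                       frozen_modules=['all']))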
# Copyright (c) Tencent Inc. All rights reserved.
from .data_preprocessor import YOLOWDetDataPreprocessor
__all__ = ['YOLOWDetDataPreprocessor']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union
import torch
from mmdet.models.data_preprocessors import DetDataPreprocessor
from mmengine.structures import BaseDataElement
from mmyolo.registry import MODELS
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
None]
@MODELS.register_module()
class YOLOWDetDataPreprocessor(DetDataPreprocessor):
"""Rewrite collate_fn to get faster training speed.
Note: It must be used together with `mmyolo.datasets.utils.yolow_collate`
"""
def __init__(self, *args, non_blocking: Optional[bool] = True, **kwargs):
super().__init__(*args, non_blocking=non_blocking, **kwargs)
def forward(self, data: dict, training: bool = False) -> dict:
"""Perform normalization, padding and bgr2rgb conversion based on
        ``DetDataPreprocessor``.
Args:
data (dict): Data sampled from dataloader.
training (bool): Whether to enable training time augmentation.
Returns:
dict: Data in the same format as the model input.
"""
if not training:
return super().forward(data, training)
data = self.cast_data(data)
inputs, data_samples = data['inputs'], data['data_samples']
assert isinstance(data['data_samples'], dict)
# TODO: Supports multi-scale training
if self._channel_conversion and inputs.shape[1] == 3:
inputs = inputs[:, [2, 1, 0], ...]
if self._enable_normalize:
inputs = (inputs - self.mean) / self.std
if self.batch_augments is not None:
for batch_aug in self.batch_augments:
inputs, data_samples = batch_aug(inputs, data_samples)
img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
data_samples_output = {
'bboxes_labels': data_samples['bboxes_labels'],
'texts': data_samples['texts'],
'img_metas': img_metas
}
if 'masks' in data_samples:
data_samples_output['masks'] = data_samples['masks']
if 'is_detection' in data_samples:
data_samples_output['is_detection'] = data_samples['is_detection']
return {'inputs': inputs, 'data_samples': data_samples_output}
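# Data layout sketch (assumed, matching the fast collate path): during
# training the preprocessor expects `data_samples` to already be a dict, e.g.
#
#   data = {
#       'inputs': batched_images,            # (B, 3, H, W)
#       'data_samples': {
#           'bboxes_labels': bboxes_labels,  # (num_gt, 6):
#                                            # [batch_idx, label, x1, y1, x2, y2]
#           'texts': texts,                  # per-image prompt lists
#       },
#   }
#
# Optional 'masks' / 'is_detection' entries are forwarded unchanged.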
# Copyright (c) Tencent Inc. All rights reserved.
from .yolo_world_head import YOLOWorldHead, YOLOWorldHeadModule, RepYOLOWorldHeadModule
from .yolo_world_seg_head import YOLOWorldSegHead, YOLOWorldSegHeadModule
__all__ = [
'YOLOWorldHead', 'YOLOWorldHeadModule', 'YOLOWorldSegHead',
'YOLOWorldSegHeadModule', 'RepYOLOWorldHeadModule'
]
# Copyright (c) Tencent Inc. All rights reserved.
import math
import copy
from typing import List, Optional, Tuple, Union, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmengine.dist import get_dist_info
from mmengine.structures import InstanceData
from mmdet.structures import SampleList
from mmdet.utils import OptConfigType, InstanceList, OptInstanceList
from mmdet.models.utils import (multi_apply, unpack_gt_instances,
filter_scores_and_topk)
from mmyolo.registry import MODELS
from mmyolo.models.dense_heads import YOLOv8HeadModule, YOLOv8Head
from mmyolo.models.utils import gt_instances_preprocess
from mmcv.cnn.bricks import build_norm_layer
@MODELS.register_module()
class ContrastiveHead(BaseModule):
"""Contrastive Head for YOLO-World
compute the region-text scores according to the
similarity between image and text features
Args:
embed_dims (int): embed dim of text and image features
"""
def __init__(self,
embed_dims: int,
init_cfg: OptConfigType = None,
use_einsum: bool = True) -> None:
super().__init__(init_cfg=init_cfg)
self.bias = nn.Parameter(torch.zeros([]))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.use_einsum = use_einsum
def forward(self, x: Tensor, w: Tensor) -> Tensor:
"""Forward function of contrastive learning."""
x = F.normalize(x, dim=1, p=2)
w = F.normalize(w, dim=-1, p=2)
if self.use_einsum:
x = torch.einsum('bchw,bkc->bkhw', x, w)
else:
batch, channel, height, width = x.shape
_, k, _ = w.shape
x = x.permute(0, 2, 3, 1) # bchw->bhwc
x = x.reshape(batch, -1, channel) # bhwc->b(hw)c
w = w.permute(0, 2, 1) # bkc->bck
x = torch.matmul(x, w)
x = x.reshape(batch, height, width, k)
x = x.permute(0, 3, 1, 2)
x = x * self.logit_scale.exp() + self.bias
return x
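# Shape note (illustrative): with image features x of shape (B, C, H, W) and
# text embeddings w of shape (B, K, C), the einsum 'bchw,bkc->bkhw' yields
# per-prompt similarity maps of shape (B, K, H, W); the non-einsum branch
# computes the same result via reshape + matmul:
#
#   x = F.normalize(torch.randn(2, 256, 80, 80), dim=1, p=2)
#   w = F.normalize(torch.randn(2, 80, 256), dim=-1, p=2)
#   scores = torch.einsum('bchw,bkc->bkhw', x, w)   # (2, 80, 80, 80)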
@MODELS.register_module()
class BNContrastiveHead(BaseModule):
""" Batch Norm Contrastive Head for YOLO-World
using batch norm instead of l2-normalization
Args:
embed_dims (int): embed dim of text and image features
norm_cfg (dict): normalization params
"""
def __init__(self,
embed_dims: int,
norm_cfg: ConfigDict,
init_cfg: OptConfigType = None,
use_einsum: bool = True) -> None:
super().__init__(init_cfg=init_cfg)
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
self.bias = nn.Parameter(torch.zeros([]))
        # using -1.0 is more stable
self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
self.use_einsum = use_einsum
def forward(self, x: Tensor, w: Tensor) -> Tensor:
"""Forward function of contrastive learning."""
x = self.norm(x)
w = F.normalize(w, dim=-1, p=2)
if self.use_einsum:
x = torch.einsum('bchw,bkc->bkhw', x, w)
else:
batch, channel, height, width = x.shape
_, k, _ = w.shape
x = x.permute(0, 2, 3, 1) # bchw->bhwc
x = x.reshape(batch, -1, channel) # bhwc->b(hw)c
w = w.permute(0, 2, 1) # bkc->bck
x = torch.matmul(x, w)
x = x.reshape(batch, height, width, k)
x = x.permute(0, 3, 1, 2)
x = x * self.logit_scale.exp() + self.bias
return x
@MODELS.register_module()
class RepBNContrastiveHead(BaseModule):
""" Batch Norm Contrastive Head for YOLO-World
using batch norm instead of l2-normalization
Args:
embed_dims (int): embed dim of text and image features
norm_cfg (dict): normalization params
"""
def __init__(self,
embed_dims: int,
num_guide_embeds: int,
norm_cfg: ConfigDict,
init_cfg: OptConfigType = None) -> None:
super().__init__(init_cfg=init_cfg)
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
self.conv = nn.Conv2d(embed_dims, num_guide_embeds, kernel_size=1)
def forward(self, x: Tensor) -> Tensor:
"""Forward function of contrastive learning."""
x = self.norm(x)
x = self.conv(x)
return x
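# Note (editorial sketch, hedged): in this reparameterized head the text
# guidance is assumed to be folded into the 1x1 conv offline, e.g. by copying
# normalized text embeddings of shape (num_guide_embeds, embed_dims) into
# `conv.weight`, so no text features are needed at inference time:
#
#   with torch.no_grad():
#       head.conv.weight.copy_(
#           F.normalize(txt_feats, dim=-1, p=2)[:, :, None, None])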
@MODELS.register_module()
class YOLOWorldHeadModule(YOLOv8HeadModule):
"""Head Module for YOLO-World
Args:
embed_dims (int): embed dim for text feautures and image features
use_bn_head (bool): use batch normalization head
"""
def __init__(self,
*args,
embed_dims: int,
use_bn_head: bool = False,
use_einsum: bool = True,
freeze_all: bool = False,
**kwargs) -> None:
self.embed_dims = embed_dims
self.use_bn_head = use_bn_head
self.use_einsum = use_einsum
self.freeze_all = freeze_all
super().__init__(*args, **kwargs)
def init_weights(self, prior_prob=0.01):
"""Initialize the weight and bias of PPYOLOE head."""
super().init_weights()
for cls_pred, cls_contrast, stride in zip(self.cls_preds,
self.cls_contrasts,
self.featmap_strides):
cls_pred[-1].bias.data[:] = 0.0 # reset bias
if hasattr(cls_contrast, 'bias'):
nn.init.constant_(
cls_contrast.bias.data,
math.log(5 / self.num_classes / (640 / stride)**2))
def _init_layers(self) -> None:
"""initialize conv layers in YOLOv8 head."""
# Init decouple head
self.cls_preds = nn.ModuleList()
self.reg_preds = nn.ModuleList()
self.cls_contrasts = nn.ModuleList()
reg_out_channels = max(
(16, self.in_channels[0] // 4, self.reg_max * 4))
cls_out_channels = max(self.in_channels[0], self.num_classes)
for i in range(self.num_levels):
self.reg_preds.append(
nn.Sequential(
ConvModule(in_channels=self.in_channels[i],
out_channels=reg_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(in_channels=reg_out_channels,
out_channels=reg_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(in_channels=reg_out_channels,
out_channels=4 * self.reg_max,
kernel_size=1)))
self.cls_preds.append(
nn.Sequential(
ConvModule(in_channels=self.in_channels[i],
out_channels=cls_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(in_channels=cls_out_channels,
out_channels=cls_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(in_channels=cls_out_channels,
out_channels=self.embed_dims,
kernel_size=1)))
if self.use_bn_head:
self.cls_contrasts.append(
BNContrastiveHead(self.embed_dims,
self.norm_cfg,
use_einsum=self.use_einsum))
else:
self.cls_contrasts.append(
ContrastiveHead(self.embed_dims,
use_einsum=self.use_einsum))
proj = torch.arange(self.reg_max, dtype=torch.float)
self.register_buffer('proj', proj, persistent=False)
if self.freeze_all:
self._freeze_all()
def _freeze_all(self):
"""Freeze the model."""
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
super().train(mode)
if self.freeze_all:
self._freeze_all()
def forward(self, img_feats: Tuple[Tensor],
txt_feats: Tensor) -> Tuple[List]:
"""Forward features from the upstream network."""
assert len(img_feats) == self.num_levels
txt_feats = [txt_feats for _ in range(self.num_levels)]
return multi_apply(self.forward_single, img_feats, txt_feats,
self.cls_preds, self.reg_preds, self.cls_contrasts)
def forward_single(self, img_feat: Tensor, txt_feat: Tensor,
cls_pred: nn.ModuleList, reg_pred: nn.ModuleList,
cls_contrast: nn.ModuleList) -> Tuple:
"""Forward feature of a single scale level."""
b, _, h, w = img_feat.shape
cls_embed = cls_pred(img_feat)
cls_logit = cls_contrast(cls_embed, txt_feat)
bbox_dist_preds = reg_pred(img_feat)
if self.reg_max > 1:
bbox_dist_preds = bbox_dist_preds.reshape(
[-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)
# TODO: The get_flops script cannot handle the situation of
# matmul, and needs to be fixed later
# bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
bbox_preds = bbox_dist_preds.softmax(3).matmul(
self.proj.view([-1, 1])).squeeze(-1)
bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
else:
bbox_preds = bbox_dist_preds
if self.training:
return cls_logit, bbox_preds, bbox_dist_preds
else:
return cls_logit, bbox_preds
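# DFL decoding sketch (illustrative numbers): with reg_max = 16, each box side
# is predicted as a distribution over 16 discrete distances, and the expected
# distance is the softmax-weighted sum with proj = [0, 1, ..., 15]:
#
#   dist_logits = torch.zeros(16)
#   dist_logits[3], dist_logits[4] = 8.0, 8.0    # mass around bins 3 and 4
#   expected = dist_logits.softmax(0) @ torch.arange(16, dtype=torch.float)
#   # -> roughly 3.5 (in stride units), matching the matmul with `self.proj`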
@MODELS.register_module()
class RepYOLOWorldHeadModule(YOLOWorldHeadModule):
def __init__(self,
*args,
embed_dims: int,
num_guide: int,
freeze_all: bool = False,
**kwargs) -> None:
super().__init__(*args,
embed_dims=embed_dims,
use_bn_head=True,
use_einsum=False,
freeze_all=freeze_all,
**kwargs)
# using rep head
cls_contrasts = []
for _ in range(self.num_levels):
cls_contrasts.append(
RepBNContrastiveHead(
embed_dims=embed_dims,
num_guide_embeds=num_guide,
norm_cfg=self.norm_cfg
)
)
self.cls_contrasts = nn.ModuleList(cls_contrasts)
def forward_single(self, img_feat: Tensor, cls_pred: nn.ModuleList,
reg_pred: nn.ModuleList,
cls_contrast: nn.ModuleList) -> Tuple:
"""Forward features from the upstream network."""
b, _, h, w = img_feat.shape
cls_embed = cls_pred(img_feat)
cls_logit = cls_contrast(cls_embed)
bbox_dist_preds = reg_pred(img_feat)
if self.reg_max > 1:
bbox_dist_preds = bbox_dist_preds.reshape(
[-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)
# TODO: The get_flops script cannot handle the situation of
# matmul, and needs to be fixed later
# bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
bbox_preds = bbox_dist_preds.softmax(3).matmul(
self.proj.view([-1, 1])).squeeze(-1)
bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
else:
bbox_preds = bbox_dist_preds
if self.training:
return cls_logit, bbox_preds, bbox_dist_preds
else:
return cls_logit, bbox_preds
def forward(self, img_feats: Tuple[Tensor]) -> Tuple[List]:
assert len(img_feats) == self.num_levels
return multi_apply(self.forward_single, img_feats, self.cls_preds,
self.reg_preds, self.cls_contrasts)
@MODELS.register_module()
class YOLOWorldHead(YOLOv8Head):
"""YOLO-World Head
"""
def __init__(self, world_size=-1, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.world_size = world_size
"""YOLO World v8 head."""
def loss(self, img_feats: Tuple[Tensor], txt_feats: Tensor,
batch_data_samples: Union[list, dict]) -> dict:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network."""
outs = self(img_feats, txt_feats)
# Fast version
loss_inputs = outs + (batch_data_samples['bboxes_labels'],
batch_data_samples['img_metas'])
losses = self.loss_by_feat(*loss_inputs)
return losses
def loss_and_predict(
self,
img_feats: Tuple[Tensor],
txt_feats: Tensor,
batch_data_samples: SampleList,
proposal_cfg: Optional[ConfigDict] = None
) -> Tuple[dict, InstanceList]:
"""Perform forward propagation of the head, then calculate loss and
predictions from the features and data samples.
"""
outputs = unpack_gt_instances(batch_data_samples)
(batch_gt_instances, batch_gt_instances_ignore,
batch_img_metas) = outputs
outs = self(img_feats, txt_feats)
loss_inputs = outs + (batch_gt_instances, batch_img_metas,
batch_gt_instances_ignore)
losses = self.loss_by_feat(*loss_inputs)
predictions = self.predict_by_feat(*outs,
batch_img_metas=batch_img_metas,
cfg=proposal_cfg)
return losses, predictions
def forward(self, img_feats: Tuple[Tensor],
txt_feats: Tensor) -> Tuple[List]:
"""Forward features from the upstream network."""
return self.head_module(img_feats, txt_feats)
def predict(self,
img_feats: Tuple[Tensor],
txt_feats: Tensor,
batch_data_samples: SampleList,
rescale: bool = False) -> InstanceList:
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network.
"""
batch_img_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
outs = self(img_feats, txt_feats)
predictions = self.predict_by_feat(*outs,
batch_img_metas=batch_img_metas,
rescale=rescale)
return predictions
def aug_test(self,
aug_batch_feats,
aug_batch_img_metas,
rescale=False,
with_ori_nms=False,
**kwargs):
"""Test function with test time augmentation."""
raise NotImplementedError('aug_test is not implemented yet.')
def loss_by_feat(
self,
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
bbox_dist_preds: Sequence[Tensor],
batch_gt_instances: Sequence[InstanceData],
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
"""Calculate the loss based on the features extracted by the detection
head.
Args:
cls_scores (Sequence[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is
num_priors * num_classes.
bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_priors * 4.
bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
each scale level with shape (bs, reg_max + 1, H*W, 4).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of losses.
"""
num_imgs = len(batch_img_metas)
current_featmap_sizes = [
cls_score.shape[2:] for cls_score in cls_scores
]
        # If the feature map shape has changed, generate new priors
if current_featmap_sizes != self.featmap_sizes_train:
self.featmap_sizes_train = current_featmap_sizes
mlvl_priors_with_stride = self.prior_generator.grid_priors(
self.featmap_sizes_train,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True)
self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
self.flatten_priors_train = torch.cat(mlvl_priors_with_stride,
dim=0)
self.stride_tensor = self.flatten_priors_train[..., [2]]
# gt info
gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
gt_labels = gt_info[:, :, :1]
gt_bboxes = gt_info[:, :, 1:] # xyxy
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
# pred info
flatten_cls_preds = [
cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_classes)
for cls_pred in cls_scores
]
flatten_pred_bboxes = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
# (bs, n, 4 * reg_max)
flatten_pred_dists = [
bbox_pred_org.reshape(num_imgs, -1, self.head_module.reg_max * 4)
for bbox_pred_org in bbox_dist_preds
]
flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
flatten_pred_bboxes = self.bbox_coder.decode(
self.flatten_priors_train[..., :2], flatten_pred_bboxes,
self.stride_tensor[..., 0])
assigned_result = self.assigner(
(flatten_pred_bboxes.detach()).type(gt_bboxes.dtype),
flatten_cls_preds.detach().sigmoid(), self.flatten_priors_train,
gt_labels, gt_bboxes, pad_bbox_flag)
assigned_bboxes = assigned_result['assigned_bboxes']
assigned_scores = assigned_result['assigned_scores']
fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
assigned_scores_sum = assigned_scores.sum().clamp(min=1)
loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores).sum()
loss_cls /= assigned_scores_sum
# rescale bbox
assigned_bboxes /= self.stride_tensor
flatten_pred_bboxes /= self.stride_tensor
# select positive samples mask
num_pos = fg_mask_pre_prior.sum()
if num_pos > 0:
            # when num_pos > 0, assigned_scores_sum will be > 0, so the
            # loss_bbox will not report an error
# iou loss
prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
pred_bboxes_pos = torch.masked_select(
flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
assigned_bboxes_pos = torch.masked_select(
assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
bbox_weight = torch.masked_select(assigned_scores.sum(-1),
fg_mask_pre_prior).unsqueeze(-1)
loss_bbox = self.loss_bbox(
pred_bboxes_pos, assigned_bboxes_pos,
weight=bbox_weight) / assigned_scores_sum
# dfl loss
pred_dist_pos = flatten_dist_preds[fg_mask_pre_prior]
assigned_ltrb = self.bbox_coder.encode(
self.flatten_priors_train[..., :2] / self.stride_tensor,
assigned_bboxes,
max_dis=self.head_module.reg_max - 1,
eps=0.01)
assigned_ltrb_pos = torch.masked_select(
assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
loss_dfl = self.loss_dfl(pred_dist_pos.reshape(
-1, self.head_module.reg_max),
assigned_ltrb_pos.reshape(-1),
weight=bbox_weight.expand(-1,
4).reshape(-1),
avg_factor=assigned_scores_sum)
else:
loss_bbox = flatten_pred_bboxes.sum() * 0
loss_dfl = flatten_pred_bboxes.sum() * 0
if self.world_size == -1:
_, world_size = get_dist_info()
else:
world_size = self.world_size
return dict(loss_cls=loss_cls * num_imgs * world_size,
loss_bbox=loss_bbox * num_imgs * world_size,
loss_dfl=loss_dfl * num_imgs * world_size)
def predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
objectnesses: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = True,
with_nms: bool = True) -> List[InstanceData]:
"""Transform a batch of output features extracted by the head into
bbox results.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
objectnesses (list[Tensor], Optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, 1, H, W).
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
                Defaults to True.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
after the post process. Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
assert len(cls_scores) == len(bbox_preds)
if objectnesses is None:
with_objectnesses = False
else:
with_objectnesses = True
assert len(cls_scores) == len(objectnesses)
cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)
multi_label = cfg.multi_label
multi_label &= self.num_classes > 1
cfg.multi_label = multi_label
num_imgs = len(batch_img_metas)
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
# If the shape does not change, use the previous mlvl_priors
if featmap_sizes != self.featmap_sizes:
self.mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device)
self.featmap_sizes = featmap_sizes
flatten_priors = torch.cat(self.mlvl_priors)
mlvl_strides = [
flatten_priors.new_full(
(featmap_size.numel() * self.num_base_priors, ), stride) for
featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
]
flatten_stride = torch.cat(mlvl_strides)
# flatten cls_scores, bbox_preds and objectness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_classes)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
flatten_decoded_bboxes = self.bbox_coder.decode(
flatten_priors[None], flatten_bbox_preds, flatten_stride)
if with_objectnesses:
flatten_objectness = [
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
for objectness in objectnesses
]
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
else:
flatten_objectness = [None for _ in range(num_imgs)]
# 8400
# print(flatten_cls_scores.shape)
results_list = []
for (bboxes, scores, objectness,
img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
flatten_objectness, batch_img_metas):
ori_shape = img_meta['ori_shape']
scale_factor = img_meta['scale_factor']
if 'pad_param' in img_meta:
pad_param = img_meta['pad_param']
else:
pad_param = None
score_thr = cfg.get('score_thr', -1)
# yolox_style does not require the following operations
if objectness is not None and score_thr > 0 and not cfg.get(
'yolox_style', False):
conf_inds = objectness > score_thr
bboxes = bboxes[conf_inds, :]
scores = scores[conf_inds, :]
objectness = objectness[conf_inds]
if objectness is not None:
# conf = obj_conf * cls_conf
scores *= objectness[:, None]
if scores.shape[0] == 0:
empty_results = InstanceData()
empty_results.bboxes = bboxes
empty_results.scores = scores[:, 0]
empty_results.labels = scores[:, 0].int()
results_list.append(empty_results)
continue
nms_pre = cfg.get('nms_pre', 100000)
if cfg.multi_label is False:
scores, labels = scores.max(1, keepdim=True)
scores, _, keep_idxs, results = filter_scores_and_topk(
scores,
score_thr,
nms_pre,
results=dict(labels=labels[:, 0]))
labels = results['labels']
else:
scores, labels, keep_idxs, _ = filter_scores_and_topk(
scores, score_thr, nms_pre)
results = InstanceData(scores=scores,
labels=labels,
bboxes=bboxes[keep_idxs])
if rescale:
if pad_param is not None:
results.bboxes -= results.bboxes.new_tensor([
pad_param[2], pad_param[0], pad_param[2], pad_param[0]
])
results.bboxes /= results.bboxes.new_tensor(
scale_factor).repeat((1, 2))
if cfg.get('yolox_style', False):
# do not need max_per_img
cfg.max_per_img = len(results)
results = self._bbox_post_process(results=results,
cfg=cfg,
rescale=False,
with_nms=with_nms,
img_meta=img_meta)
results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
results_list.append(results)
return results_list
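# Inference config sketch (values are typical mmyolo-style defaults, not taken
# from a released YOLO-World config): `predict_by_feat` reads its
# post-processing options from `test_cfg`, roughly of the form
#
#   test_cfg = dict(
#       multi_label=True,      # a box may keep scores for several classes
#       nms_pre=30000,         # top-k boxes kept before NMS
#       score_thr=0.001,
#       nms=dict(type='nms', iou_threshold=0.7),
#       max_per_img=300)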
# Copyright (c) Lin Song. All rights reserved.
import math
from typing import List, Optional, Tuple, Union, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.dist import get_dist_info
from mmengine.structures import InstanceData
from mmdet.structures import SampleList
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
OptMultiConfig, InstanceList)
from mmdet.models.utils import multi_apply, unpack_gt_instances
from mmyolo.models.dense_heads import YOLOv8HeadModule
from mmyolo.models.utils import gt_instances_preprocess
from mmyolo.registry import MODELS, TASK_UTILS
from mmyolo.models.dense_heads.yolov5_ins_head import (
ProtoModule, YOLOv5InsHead
)
from .yolo_world_head import ContrastiveHead, BNContrastiveHead
@MODELS.register_module()
class YOLOWorldSegHeadModule(YOLOv8HeadModule):
def __init__(self,
*args,
embed_dims: int,
proto_channels: int,
mask_channels: int,
freeze_bbox: bool = False,
freeze_all: bool = False,
use_bn_head: bool = False,
**kwargs) -> None:
self.embed_dims = embed_dims
self.proto_channels = proto_channels
self.mask_channels = mask_channels
self.freeze_bbox = freeze_bbox
self.freeze_all = freeze_all
self.use_bn_head = use_bn_head
super().__init__(*args, **kwargs)
def init_weights(self, prior_prob=0.01):
"""Initialize the weight and bias of PPYOLOE head."""
super().init_weights()
for cls_pred, cls_contrast, stride in zip(self.cls_preds,
self.cls_contrasts,
self.featmap_strides):
cls_pred[-1].bias.data[:] = 0.0 # reset bias
if hasattr(cls_contrast, 'bias'):
nn.init.constant_(
cls_contrast.bias.data,
math.log(5 / self.num_classes / (640 / stride)**2))
def _init_layers(self) -> None:
"""initialize conv layers in YOLOv8 head."""
# Init decouple head
self.cls_preds = nn.ModuleList()
self.reg_preds = nn.ModuleList()
self.seg_preds = nn.ModuleList()
self.cls_contrasts = nn.ModuleList()
reg_out_channels = max(
(16, self.in_channels[0] // 4, self.reg_max * 4))
seg_out_channels = max(self.in_channels[0] // 4, self.mask_channels)
cls_out_channels = max(self.in_channels[0], self.num_classes)
        # copy so that freezing the bbox branch does not also freeze the
        # norm layers of the segmentation branch, which use self.norm_cfg
        bbox_norm_cfg = dict(self.norm_cfg)
        bbox_norm_cfg['requires_grad'] = not self.freeze_bbox
        if self.freeze_all:
            self.norm_cfg['requires_grad'] = False
            bbox_norm_cfg['requires_grad'] = False
for i in range(self.num_levels):
self.reg_preds.append(
nn.Sequential(
ConvModule(in_channels=self.in_channels[i],
out_channels=reg_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=bbox_norm_cfg,
act_cfg=self.act_cfg),
ConvModule(in_channels=reg_out_channels,
out_channels=reg_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=bbox_norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(in_channels=reg_out_channels,
out_channels=4 * self.reg_max,
kernel_size=1)))
self.cls_preds.append(
nn.Sequential(
ConvModule(in_channels=self.in_channels[i],
out_channels=cls_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=bbox_norm_cfg,
act_cfg=self.act_cfg),
ConvModule(in_channels=cls_out_channels,
out_channels=cls_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=bbox_norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(in_channels=cls_out_channels,
out_channels=self.embed_dims,
kernel_size=1)))
self.seg_preds.append(
nn.Sequential(
ConvModule(in_channels=self.in_channels[i],
out_channels=seg_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(in_channels=seg_out_channels,
out_channels=seg_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(in_channels=seg_out_channels,
out_channels=self.mask_channels,
kernel_size=1)))
if self.use_bn_head:
self.cls_contrasts.append(
BNContrastiveHead(self.embed_dims, self.norm_cfg))
else:
self.cls_contrasts.append(ContrastiveHead(self.embed_dims))
proj = torch.arange(self.reg_max, dtype=torch.float)
self.register_buffer('proj', proj, persistent=False)
self.proto_pred = ProtoModule(in_channels=self.in_channels[0],
middle_channels=self.proto_channels,
mask_channels=self.mask_channels,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
        if self.freeze_bbox or self.freeze_all:
self._freeze_all()
def _freeze_all(self):
frozen_list = [self.cls_preds, self.reg_preds, self.cls_contrasts]
if self.freeze_all:
frozen_list.extend([self.proto_pred, self.seg_preds])
for module in frozen_list:
for m in module.modules():
if isinstance(m, _BatchNorm):
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode: bool = True):
"""Convert the model into training mode while keep normalization layer
frozen."""
super().train(mode)
if self.freeze_bbox or self.freeze_all:
self._freeze_all()
def forward(self, img_feats: Tuple[Tensor],
txt_feats: Tensor) -> Tuple[List]:
"""Forward features from the upstream network."""
assert len(img_feats) == self.num_levels
txt_feats = [txt_feats for _ in range(self.num_levels)]
mask_protos = self.proto_pred(img_feats[0])
cls_logit, bbox_preds, bbox_dist_preds, coeff_preds = multi_apply(
self.forward_single, img_feats, txt_feats, self.cls_preds,
self.reg_preds, self.cls_contrasts, self.seg_preds)
if self.training:
return cls_logit, bbox_preds, bbox_dist_preds, coeff_preds, mask_protos
else:
return cls_logit, bbox_preds, None, coeff_preds, mask_protos
def forward_single(self, img_feat: Tensor, txt_feat: Tensor,
cls_pred: nn.ModuleList, reg_pred: nn.ModuleList,
cls_contrast: nn.ModuleList,
seg_pred: nn.ModuleList) -> Tuple:
"""Forward feature of a single scale level."""
b, _, h, w = img_feat.shape
cls_embed = cls_pred(img_feat)
cls_logit = cls_contrast(cls_embed, txt_feat)
bbox_dist_preds = reg_pred(img_feat)
coeff_pred = seg_pred(img_feat)
if self.reg_max > 1:
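            # Decode distances from the DFL distribution: softmax over the
            # ``reg_max`` bins, then take the expectation w.r.t. ``self.proj``.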
bbox_dist_preds = bbox_dist_preds.reshape(
[-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)
# TODO: The get_flops script cannot handle the situation of
# matmul, and needs to be fixed later
# bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
bbox_preds = bbox_dist_preds.softmax(3).matmul(
self.proj.view([-1, 1])).squeeze(-1)
bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
else:
bbox_preds = bbox_dist_preds
if self.training:
return cls_logit, bbox_preds, bbox_dist_preds, coeff_pred
else:
return cls_logit, bbox_preds, None, coeff_pred
@MODELS.register_module()
class YOLOWorldSegHead(YOLOv5InsHead):
def __init__(self,
head_module: ConfigType,
prior_generator: ConfigType = dict(
type='mmdet.MlvlPointGenerator',
offset=0.5,
strides=[8, 16, 32]),
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
loss_cls: ConfigType = dict(type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='none',
loss_weight=0.5),
loss_bbox: ConfigType = dict(type='IoULoss',
iou_mode='ciou',
bbox_format='xyxy',
reduction='sum',
loss_weight=7.5,
return_iou=False),
loss_dfl=dict(type='mmdet.DistributionFocalLoss',
reduction='mean',
loss_weight=1.5 / 4),
mask_overlap: bool = True,
loss_mask: ConfigType = dict(type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='none'),
loss_mask_weight=0.05,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(head_module=head_module,
prior_generator=prior_generator,
bbox_coder=bbox_coder,
loss_cls=loss_cls,
loss_bbox=loss_bbox,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
self.loss_dfl = MODELS.build(loss_dfl)
self.loss_obj = None
self.mask_overlap = mask_overlap
self.loss_mask: nn.Module = MODELS.build(loss_mask)
self.loss_mask_weight = loss_mask_weight
def special_init(self):
"""Since YOLO series algorithms will inherit from YOLOv5Head, but
different algorithms have special initialization process.
The special_init function is designed to deal with this situation.
"""
if self.train_cfg:
self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
        # Cache common attributes to reduce repeated computation
self.featmap_sizes_train = None
self.num_level_priors = None
self.flatten_priors_train = None
self.stride_tensor = None
"""YOLO World head."""
def loss(self, img_feats: Tuple[Tensor], txt_feats: Tensor,
batch_data_samples: Union[list, dict]) -> dict:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network."""
outs = self(img_feats, txt_feats)
        # Fast version: ``batch_data_samples`` is a pre-collated dict of
        # tensors rather than a list of data samples.
loss_inputs = outs + (batch_data_samples['bboxes_labels'],
batch_data_samples['masks'],
batch_data_samples['img_metas'])
losses = self.loss_by_feat(*loss_inputs)
return losses
def loss_and_predict(
self,
img_feats: Tuple[Tensor],
txt_feats: Tensor,
batch_data_samples: SampleList,
proposal_cfg: Optional[ConfigDict] = None
) -> Tuple[dict, InstanceList]:
"""Perform forward propagation of the head, then calculate loss and
predictions from the features and data samples.
"""
outputs = unpack_gt_instances(batch_data_samples)
(batch_gt_instances, batch_gt_instances_ignore,
batch_img_metas) = outputs
outs = self(img_feats, txt_feats)
loss_inputs = outs + (batch_gt_instances, batch_img_metas,
batch_gt_instances_ignore)
losses = self.loss_by_feat(*loss_inputs)
predictions = self.predict_by_feat(*outs,
batch_img_metas=batch_img_metas,
cfg=proposal_cfg)
return losses, predictions
def forward(self, img_feats: Tuple[Tensor],
txt_feats: Tensor) -> Tuple[List]:
"""Forward features from the upstream network."""
return self.head_module(img_feats, txt_feats)
def predict(self,
img_feats: Tuple[Tensor],
txt_feats: Tensor,
batch_data_samples: SampleList,
rescale: bool = False) -> InstanceList:
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network.
"""
batch_img_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
outs = self(img_feats, txt_feats)
predictions = self.predict_by_feat(*outs,
batch_img_metas=batch_img_metas,
rescale=rescale)
return predictions
def aug_test(self,
aug_batch_feats,
aug_batch_img_metas,
rescale=False,
with_ori_nms=False,
**kwargs):
"""Test function with test time augmentation."""
raise NotImplementedError('aug_test is not implemented yet.')
def loss_by_feat(
self,
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
bbox_dist_preds: Sequence[Tensor],
coeff_preds: Sequence[Tensor],
proto_preds: Tensor,
batch_gt_instances: Sequence[InstanceData],
batch_gt_masks: Sequence[Tensor],
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
"""Calculate the loss based on the features extracted by the detection
head.
Args:
cls_scores (Sequence[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is
num_priors * num_classes.
bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_priors * 4.
            bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
                each scale level with shape (bs, H*W, 4, reg_max).
            coeff_preds (Sequence[Tensor]): Mask coefficient predictions for
                each scale level, each is a 4D-tensor, the channel number is
                num_priors * mask_channels.
            proto_preds (Tensor): Mask prototypes with shape
                (bs, mask_channels, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_gt_masks (Sequence[Tensor]): Ground-truth masks for the
                whole batch, flattened across images.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of losses.
"""
num_imgs = len(batch_img_metas)
current_featmap_sizes = [
cls_score.shape[2:] for cls_score in cls_scores
]
# If the shape does not equal, generate new one
if current_featmap_sizes != self.featmap_sizes_train:
self.featmap_sizes_train = current_featmap_sizes
mlvl_priors_with_stride = self.prior_generator.grid_priors(
self.featmap_sizes_train,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True)
self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
self.flatten_priors_train = torch.cat(mlvl_priors_with_stride,
dim=0)
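            # Priors generated with ``with_stride=True`` are
            # (x, y, stride_w, stride_h); keep one stride column for
            # rescaling boxes and distances.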
self.stride_tensor = self.flatten_priors_train[..., [2]]
# gt info
gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
gt_labels = gt_info[:, :, :1]
gt_bboxes = gt_info[:, :, 1:] # xyxy
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
# pred info
flatten_cls_preds = [
cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_classes)
for cls_pred in cls_scores
]
flatten_pred_bboxes = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
# (bs, n, 4 * reg_max)
flatten_pred_dists = [
bbox_pred_org.reshape(num_imgs, -1, self.head_module.reg_max * 4)
for bbox_pred_org in bbox_dist_preds
]
flatten_pred_coeffs = [
coeff_pred.permute(0, 2, 3,
1).reshape(num_imgs, -1,
self.head_module.mask_channels)
for coeff_pred in coeff_preds
]
flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
flatten_pred_bboxes = self.bbox_coder.decode(
self.flatten_priors_train[..., :2], flatten_pred_bboxes,
self.stride_tensor[..., 0])
flatten_pred_coeffs = torch.cat(flatten_pred_coeffs, dim=1)
assigned_result = self.assigner(
(flatten_pred_bboxes.detach()).type(gt_bboxes.dtype),
flatten_cls_preds.detach().sigmoid(), self.flatten_priors_train,
gt_labels, gt_bboxes, pad_bbox_flag)
assigned_bboxes = assigned_result['assigned_bboxes']
assigned_scores = assigned_result['assigned_scores']
fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
assigned_gt_idxs = assigned_result['assigned_gt_idxs']
assigned_scores_sum = assigned_scores.sum().clamp(min=1)
loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores).sum()
loss_cls /= assigned_scores_sum
# rescale bbox
assigned_bboxes /= self.stride_tensor
flatten_pred_bboxes /= self.stride_tensor
# select positive samples mask
num_pos = fg_mask_pre_prior.sum()
if num_pos > 0:
            # when num_pos > 0, assigned_scores_sum will be > 0, so loss_bbox
            # will not raise an error
# iou loss
prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
pred_bboxes_pos = torch.masked_select(
flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
assigned_bboxes_pos = torch.masked_select(
assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
bbox_weight = torch.masked_select(assigned_scores.sum(-1),
fg_mask_pre_prior).unsqueeze(-1)
loss_bbox = self.loss_bbox(
pred_bboxes_pos, assigned_bboxes_pos,
weight=bbox_weight) / assigned_scores_sum
# dfl loss
pred_dist_pos = flatten_dist_preds[fg_mask_pre_prior]
assigned_ltrb = self.bbox_coder.encode(
self.flatten_priors_train[..., :2] / self.stride_tensor,
assigned_bboxes,
max_dis=self.head_module.reg_max - 1,
eps=0.01)
assigned_ltrb_pos = torch.masked_select(
assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
loss_dfl = self.loss_dfl(pred_dist_pos.reshape(
-1, self.head_module.reg_max),
assigned_ltrb_pos.reshape(-1),
weight=bbox_weight.expand(-1,
4).reshape(-1),
avg_factor=assigned_scores_sum)
_, c, mask_h, mask_w = proto_preds.shape
if batch_gt_masks.shape[-2:] != (mask_h, mask_w):
batch_gt_masks = F.interpolate(batch_gt_masks[None],
(mask_h, mask_w),
mode='nearest')[0]
loss_mask = torch.zeros(1, device=loss_dfl.device)
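            # ``assigned_gt_idxs`` index gt boxes per image; offset them by
            # the cumulative number of gt boxes per image so they can index
            # into the batch-flattened ``batch_gt_masks``.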
box_sum_flag = pad_bbox_flag.long().sum(dim=1).squeeze(1)
batch_inds = torch.zeros(num_imgs,
dtype=torch.int64,
device=assigned_gt_idxs.device)[:, None]
batch_inds[1:] = box_sum_flag.cumsum(dim=0)[:-1][..., None]
_assigned_gt_idxs = assigned_gt_idxs + batch_inds
for bs in range(num_imgs):
                # per-prior gt indices (e.g. 8400 priors for a 640x640 input)
bbox_match_inds = assigned_gt_idxs[bs]
mask_match_inds = _assigned_gt_idxs[bs]
bbox_match_inds = torch.masked_select(bbox_match_inds,
fg_mask_pre_prior[bs])
mask_match_inds = torch.masked_select(mask_match_inds,
fg_mask_pre_prior[bs])
# mask
mask_dim = coeff_preds[0].shape[1]
prior_mask_mask = fg_mask_pre_prior[bs].unsqueeze(-1).repeat(
[1, mask_dim])
pred_coeffs_pos = torch.masked_select(flatten_pred_coeffs[bs],
prior_mask_mask).reshape(
[-1, mask_dim])
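                # NOTE: the divisors below assume a 640x640 input with mask
                # prototypes at 1/4 resolution: ``match_boxes`` are gt boxes
                # in prototype coordinates and ``normed_boxes`` are normalized
                # to [0, 1] for area-based loss normalization.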
match_boxes = gt_bboxes[bs][bbox_match_inds] / 4
normed_boxes = gt_bboxes[bs][bbox_match_inds] / 640
bbox_area = (normed_boxes[:, 2:] -
normed_boxes[:, :2]).prod(dim=1)
if not mask_match_inds.any():
continue
assert not self.mask_overlap
mask_gti = batch_gt_masks[mask_match_inds]
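                # Per-instance masks are linear combinations of the prototype
                # masks, weighted by the predicted mask coefficients.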
mask_preds = (
pred_coeffs_pos @ proto_preds[bs].view(c, -1)).view(
-1, mask_h, mask_w)
loss_mask_full = self.loss_mask(mask_preds, mask_gti)
_loss_mask = (self.crop_mask(loss_mask_full[None],
match_boxes).mean(dim=(2, 3)) /
bbox_area)
loss_mask += _loss_mask.mean()
else:
loss_bbox = flatten_pred_bboxes.sum() * 0
loss_dfl = flatten_pred_bboxes.sum() * 0
loss_mask = flatten_pred_coeffs.sum() * 0
_, world_size = get_dist_info()
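        # Scale the losses by batch size and world size so the effective loss
        # stays consistent with single-GPU training under DDP gradient
        # averaging.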
return dict(loss_cls=loss_cls * num_imgs * world_size,
loss_bbox=loss_bbox * num_imgs * world_size,
loss_dfl=loss_dfl * num_imgs * world_size,
loss_mask=loss_mask * self.loss_mask_weight * world_size)
# Copyright (c) Tencent Inc. All rights reserved.
from .yolo_world import YOLOWorldDetector, SimpleYOLOWorldDetector
__all__ = ['YOLOWorldDetector', 'SimpleYOLOWorldDetector']
# Copyright (c) Tencent Inc. All rights reserved.
from typing import List, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from mmdet.structures import OptSampleList, SampleList
from mmyolo.models.detectors import YOLODetector
from mmyolo.registry import MODELS
@MODELS.register_module()
class YOLOWorldDetector(YOLODetector):
"""Implementation of YOLOW Series"""
def __init__(self,
*args,
mm_neck: bool = False,
num_train_classes=80,
num_test_classes=80,
**kwargs) -> None:
self.mm_neck = mm_neck
self.num_train_classes = num_train_classes
self.num_test_classes = num_test_classes
super().__init__(*args, **kwargs)
def loss(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> Union[dict, list]:
"""Calculate losses from a batch of inputs and data samples."""
self.bbox_head.num_classes = self.num_train_classes
img_feats, txt_feats = self.extract_feat(batch_inputs,
batch_data_samples)
losses = self.bbox_head.loss(img_feats, txt_feats, batch_data_samples)
return losses
def predict(self,
batch_inputs: Tensor,
batch_data_samples: SampleList,
rescale: bool = True) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
"""
img_feats, txt_feats = self.extract_feat(batch_inputs,
batch_data_samples)
# self.bbox_head.num_classes = self.num_test_classes
self.bbox_head.num_classes = txt_feats[0].shape[0]
results_list = self.bbox_head.predict(img_feats,
txt_feats,
batch_data_samples,
rescale=rescale)
batch_data_samples = self.add_pred_to_datasample(
batch_data_samples, results_list)
return batch_data_samples
def reparameterize(self, texts: List[List[str]]) -> None:
# encode text embeddings into the detector
self.texts = texts
self.text_feats = self.backbone.forward_text(texts)
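        # Example (hypothetical usage; names are illustrative): cache text
        # features for a fixed vocabulary once, then run inference without
        # per-batch texts:
        #   detector.reparameterize([['person'], ['dog'], ['backpack']])
        #   results = detector.predict(batch_inputs, batch_data_samples)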
def _forward(
self,
batch_inputs: Tensor,
batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
"""Network forward process. Usually includes backbone, neck and head
forward without any post-processing.
"""
img_feats, txt_feats = self.extract_feat(batch_inputs,
batch_data_samples)
results = self.bbox_head.forward(img_feats, txt_feats)
return results
def extract_feat(
self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
"""Extract features."""
txt_feats = None
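        # Text features come either from the current data samples or from
        # embeddings cached by ``reparameterize``; if neither is available a
        # TypeError is raised.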
if batch_data_samples is None:
texts = self.texts
txt_feats = self.text_feats
elif isinstance(batch_data_samples,
dict) and 'texts' in batch_data_samples:
texts = batch_data_samples['texts']
elif isinstance(batch_data_samples, list) and hasattr(
batch_data_samples[0], 'texts'):
texts = [data_sample.texts for data_sample in batch_data_samples]
elif hasattr(self, 'text_feats'):
texts = self.texts
txt_feats = self.text_feats
else:
raise TypeError('batch_data_samples should be dict or list.')
if txt_feats is not None:
# forward image only
img_feats = self.backbone.forward_image(batch_inputs)
else:
img_feats, txt_feats = self.backbone(batch_inputs, texts)
if self.with_neck:
if self.mm_neck:
img_feats = self.neck(img_feats, txt_feats)
else:
img_feats = self.neck(img_feats)
return img_feats, txt_feats
@MODELS.register_module()
class SimpleYOLOWorldDetector(YOLODetector):
"""Implementation of YOLO World Series"""
def __init__(self,
*args,
mm_neck: bool = False,
num_train_classes=80,
num_test_classes=80,
prompt_dim=512,
num_prompts=80,
embedding_path='',
reparameterized=False,
freeze_prompt=False,
use_mlp_adapter=False,
**kwargs) -> None:
self.mm_neck = mm_neck
self.num_training_classes = num_train_classes
self.num_test_classes = num_test_classes
self.prompt_dim = prompt_dim
self.num_prompts = num_prompts
self.reparameterized = reparameterized
self.freeze_prompt = freeze_prompt
self.use_mlp_adapter = use_mlp_adapter
super().__init__(*args, **kwargs)
if not self.reparameterized:
if len(embedding_path) > 0:
import numpy as np
self.embeddings = torch.nn.Parameter(
torch.from_numpy(np.load(embedding_path)).float())
else:
# random init
embeddings = nn.functional.normalize(torch.randn(
(num_prompts, prompt_dim)),
dim=-1)
self.embeddings = nn.Parameter(embeddings)
if self.freeze_prompt:
self.embeddings.requires_grad = False
else:
self.embeddings.requires_grad = True
if use_mlp_adapter:
self.adapter = nn.Sequential(
nn.Linear(prompt_dim, prompt_dim * 2), nn.ReLU(True),
nn.Linear(prompt_dim * 2, prompt_dim))
else:
self.adapter = None
def loss(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> Union[dict, list]:
"""Calculate losses from a batch of inputs and data samples."""
self.bbox_head.num_classes = self.num_training_classes
img_feats, txt_feats = self.extract_feat(batch_inputs,
batch_data_samples)
if self.reparameterized:
losses = self.bbox_head.loss(img_feats, batch_data_samples)
else:
losses = self.bbox_head.loss(img_feats, txt_feats,
batch_data_samples)
return losses
def predict(self,
batch_inputs: Tensor,
batch_data_samples: SampleList,
rescale: bool = True) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
"""
img_feats, txt_feats = self.extract_feat(batch_inputs,
batch_data_samples)
self.bbox_head.num_classes = self.num_test_classes
if self.reparameterized:
results_list = self.bbox_head.predict(img_feats,
batch_data_samples,
rescale=rescale)
else:
results_list = self.bbox_head.predict(img_feats,
txt_feats,
batch_data_samples,
rescale=rescale)
batch_data_samples = self.add_pred_to_datasample(
batch_data_samples, results_list)
return batch_data_samples
def _forward(
self,
batch_inputs: Tensor,
batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
"""Network forward process. Usually includes backbone, neck and head
forward without any post-processing.
"""
img_feats, txt_feats = self.extract_feat(batch_inputs,
batch_data_samples)
if self.reparameterized:
results = self.bbox_head.forward(img_feats)
else:
results = self.bbox_head.forward(img_feats, txt_feats)
return results
def extract_feat(
self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
"""Extract features."""
# only image features
img_feats, _ = self.backbone(batch_inputs, None)
if not self.reparameterized:
# use embeddings
txt_feats = self.embeddings[None]
if self.adapter is not None:
txt_feats = self.adapter(txt_feats) + txt_feats
txt_feats = nn.functional.normalize(txt_feats, dim=-1, p=2)
txt_feats = txt_feats.repeat(img_feats[0].shape[0], 1, 1)
else:
txt_feats = None
if self.with_neck:
if self.mm_neck:
img_feats = self.neck(img_feats, txt_feats)
else:
img_feats = self.neck(img_feats)
return img_feats, txt_feats
# Copyright (c) Tencent Inc. All rights reserved.
# Basic brick modules for PAFPN based on CSPLayers
from .yolo_bricks import (
CSPLayerWithTwoConv,
MaxSigmoidAttnBlock,
MaxSigmoidCSPLayerWithTwoConv,
ImagePoolingAttentionModule,
RepConvMaxSigmoidCSPLayerWithTwoConv,
RepMaxSigmoidCSPLayerWithTwoConv
)
__all__ = ['CSPLayerWithTwoConv',
'MaxSigmoidAttnBlock',
'MaxSigmoidCSPLayerWithTwoConv',
'RepConvMaxSigmoidCSPLayerWithTwoConv',
'RepMaxSigmoidCSPLayerWithTwoConv',
'ImagePoolingAttentionModule']