ModelZoo / YOLO-World_pytorch · Commits

Commit e9cee049, authored May 31, 2024 by luopl

Initial commit

Pipeline #1056 canceled with stages
Changes: 166 files in total; showing 20 changed files with 2509 additions and 0 deletions (+2509, -0):

- yolo_world/datasets/yolov5_mixed_grounding.py (+200, -0)
- yolo_world/datasets/yolov5_obj365v1.py (+15, -0)
- yolo_world/datasets/yolov5_obj365v2.py (+15, -0)
- yolo_world/datasets/yolov5_v3det.py (+110, -0)
- yolo_world/engine/__init__.py (+2, -0)
- yolo_world/engine/optimizers/__init__.py (+4, -0)
- yolo_world/engine/optimizers/yolow_v5_optim_constructor.py (+187, -0)
- yolo_world/models/__init__.py (+9, -0)
- yolo_world/models/assigner/__init__.py (+4, -0)
- yolo_world/models/assigner/task_aligned_assigner.py (+108, -0)
- yolo_world/models/backbones/__init__.py (+16, -0)
- yolo_world/models/backbones/mm_backbone.py (+227, -0)
- yolo_world/models/data_preprocessors/__init__.py (+4, -0)
- yolo_world/models/data_preprocessors/data_preprocessor.py (+63, -0)
- yolo_world/models/dense_heads/__init__.py (+8, -0)
- yolo_world/models/dense_heads/yolo_world_head.py (+734, -0)
- yolo_world/models/dense_heads/yolo_world_seg_head.py (+550, -0)
- yolo_world/models/detectors/__init__.py (+4, -0)
- yolo_world/models/detectors/yolo_world.py (+231, -0)
- yolo_world/models/layers/__init__.py (+18, -0)
yolo_world/datasets/yolov5_mixed_grounding.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
import os.path as osp
from typing import List, Union

from mmengine.fileio import get_local_path, join_path
from mmengine.utils import is_abs
from mmdet.datasets.coco import CocoDataset
from mmyolo.registry import DATASETS
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset


@DATASETS.register_module()
class YOLOv5MixedGroundingDataset(BatchShapePolicyDataset, CocoDataset):
    """Mixed grounding dataset."""

    METAINFO = {'classes': ('object', ), 'palette': [(220, 20, 60)]}

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``

        Returns:
            List[dict]: A list of annotations.
        """  # noqa: E501
        with get_local_path(self.ann_file,
                            backend_args=self.backend_args) as local_path:
            self.coco = self.COCOAPI(local_path)

        img_ids = self.coco.get_img_ids()
        data_list = []
        total_ann_ids = []
        for img_id in img_ids:
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id

            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
            raw_ann_info = self.coco.load_anns(ann_ids)
            total_ann_ids.extend(ann_ids)

            parsed_data_info = self.parse_data_info({
                'raw_ann_info': raw_ann_info,
                'raw_img_info': raw_img_info
            })
            data_list.append(parsed_data_info)
        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(total_ann_ids), \
                f"Annotation ids in '{self.ann_file}' are not unique!"
        del self.coco
        # print(len(data_list))
        return data_list

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``

        Returns:
            Union[dict, List[dict]]: Parsed annotation.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']
        data_info = {}

        img_path = None
        img_prefix = self.data_prefix.get('img', None)
        if isinstance(img_prefix, str):
            img_path = osp.join(img_prefix, img_info['file_name'])
        elif isinstance(img_prefix, (list, tuple)):
            for prefix in img_prefix:
                candidate_img_path = osp.join(prefix, img_info['file_name'])
                if osp.exists(candidate_img_path):
                    img_path = candidate_img_path
                    break
        assert img_path is not None, (
            f'Image path {img_info["file_name"]} not found in'
            f' {img_prefix}')

        if self.data_prefix.get('seg', None):
            seg_map_path = osp.join(
                self.data_prefix['seg'],
                img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
        else:
            seg_map_path = None
        data_info['img_path'] = img_path
        data_info['img_id'] = img_info['img_id']
        data_info['seg_map_path'] = seg_map_path
        data_info['height'] = float(img_info['height'])
        data_info['width'] = float(img_info['width'])

        cat2id = {}
        texts = []
        for ann in ann_info:
            cat_name = ' '.join([
                img_info['caption'][t[0]:t[1]]
                for t in ann['tokens_positive']
            ])
            if cat_name not in cat2id:
                cat2id[cat_name] = len(cat2id)
                texts.append([cat_name])
        data_info['texts'] = texts

        instances = []
        for i, ann in enumerate(ann_info):
            instance = {}
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0,
                          min(x1 + w, float(img_info['width'])) - max(x1, 0))
            inter_h = max(0,
                          min(y1 + h, float(img_info['height'])) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                instance['ignore_flag'] = 1
            else:
                instance['ignore_flag'] = 0
            instance['bbox'] = bbox
            cat_name = ' '.join([
                img_info['caption'][t[0]:t[1]]
                for t in ann['tokens_positive']
            ])
            instance['bbox_label'] = cat2id[cat_name]
            if ann.get('segmentation', None):
                instance['mask'] = ann['segmentation']
            instances.append(instance)

        # NOTE: for detection task, we set `is_detection` to 1
        data_info['is_detection'] = 1
        data_info['instances'] = instances
        # print(data_info['texts'])
        return data_info

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        if self.filter_cfg is None:
            return self.data_list

        filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
        min_size = self.filter_cfg.get('min_size', 0)

        # obtain images that contain annotation
        ids_with_ann = set(data_info['img_id']
                           for data_info in self.data_list)

        valid_data_infos = []
        for i, data_info in enumerate(self.data_list):
            img_id = data_info['img_id']
            width = int(data_info['width'])
            height = int(data_info['height'])
            if filter_empty_gt and img_id not in ids_with_ann:
                continue
            if min(width, height) >= min_size:
                valid_data_infos.append(data_info)

        return valid_data_infos

    def _join_prefix(self):
        """Join ``self.data_root`` with ``self.data_prefix`` and
        ``self.ann_file``.
        """
        # Automatically join annotation file path with `self.root` if
        # `self.ann_file` is not an absolute path.
        if self.ann_file and not is_abs(self.ann_file) and self.data_root:
            self.ann_file = join_path(self.data_root, self.ann_file)
        # Automatically join data directory with `self.root` if path value in
        # `self.data_prefix` is not an absolute path.
        for data_key, prefix in self.data_prefix.items():
            if isinstance(prefix, (list, tuple)):
                abs_prefix = []
                for p in prefix:
                    if not is_abs(p) and self.data_root:
                        abs_prefix.append(join_path(self.data_root, p))
                    else:
                        abs_prefix.append(p)
                self.data_prefix[data_key] = abs_prefix
            elif isinstance(prefix, str):
                if not is_abs(prefix) and self.data_root:
                    self.data_prefix[data_key] = join_path(
                        self.data_root, prefix)
                else:
                    self.data_prefix[data_key] = prefix
            else:
                raise TypeError('prefix should be a string, tuple or list, '
                                f'but got {type(prefix)}')
```
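The class above registers itself with mmyolo's `DATASETS` registry, so it can be instantiated from a config. A minimal sketch of such a config entry; the paths and pipeline are placeholders, not values from this commit:

```python
# Hypothetical mmyolo-style config snippet. All paths below are
# placeholders; only the keys mirror what the dataset class reads.
mixed_grounding_dataset = dict(
    type='YOLOv5MixedGroundingDataset',
    data_root='data/mixed_grounding/',               # placeholder root
    ann_file='annotations/final_mixed_train.json',   # placeholder file
    data_prefix=dict(img='images/'),   # str, or a list/tuple of roots to try
    filter_cfg=dict(filter_empty_gt=False, min_size=32),
    pipeline=[])                       # training transforms go here
```

Note that `data_prefix['img']` may be a list or tuple: `parse_data_info` probes each prefix with `osp.exists` and takes the first match.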
yolo_world/datasets/yolov5_obj365v1.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
from mmdet.datasets import Objects365V1Dataset
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS


@DATASETS.register_module()
class YOLOv5Objects365V1Dataset(BatchShapePolicyDataset, Objects365V1Dataset):
    """Dataset for YOLOv5-style Objects365 v1.

    We only add the `BatchShapePolicy` function compared with
    Objects365V1Dataset. See `mmyolo/datasets/utils.py#BatchShapePolicy`
    for details.
    """
    pass
```
yolo_world/datasets/yolov5_obj365v2.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
from mmdet.datasets import Objects365V2Dataset
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS


@DATASETS.register_module()
class YOLOv5Objects365V2Dataset(BatchShapePolicyDataset, Objects365V2Dataset):
    """Dataset for YOLOv5-style Objects365 v2.

    We only add the `BatchShapePolicy` function compared with
    Objects365V2Dataset. See `mmyolo/datasets/utils.py#BatchShapePolicy`
    for details.
    """
    pass
```
yolo_world/datasets/yolov5_v3det.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
import copy
import json
import os.path as osp
from typing import List

from mmengine.fileio import get_local_path
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets import CocoDataset
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS

v3det_ignore_list = [
    'a00013820/26_275_28143226914_ff3a247c53_c.jpg',
    'n03815615/12_1489_32968099046_be38fa580e_c.jpg',
    'n04550184/19_1480_2504784164_ffa3db8844_c.jpg',
    'a00008703/2_363_3576131784_dfac6fc6ce_c.jpg',
    'n02814533/28_2216_30224383848_a90697f1b3_c.jpg',
    'n12026476/29_186_15091304754_5c219872f7_c.jpg',
    'n01956764/12_2004_50133201066_72e0d9fea5_c.jpg',
    'n03785016/14_2642_518053131_d07abcb5da_c.jpg',
    'a00011156/33_250_4548479728_9ce5246596_c.jpg',
    'a00009461/19_152_2792869324_db95bebc84_c.jpg',
]

# # ugly code here
# with open(osp.join("data/v3det/cats.json"), 'r') as f:
#     _classes = json.load(f)['classes']


@DATASETS.register_module()
class V3DetDataset(CocoDataset):
    """V3Det dataset for detection."""

    METAINFO = {'classes': 'classes', 'palette': None}

    COCOAPI = COCO
    # ann_id is unique in coco dataset.
    ANN_ID_UNIQUE = True

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``

        Returns:
            List[dict]: A list of annotations.
        """  # noqa: E501
        with get_local_path(self.ann_file,
                            backend_args=self.backend_args) as local_path:
            self.coco = self.COCOAPI(local_path)
        # 'categories' list in objects365_train.json and objects365_val.json
        # is inconsistent, need sort list(or dict) before get cat_ids.
        cats = self.coco.cats
        sorted_cats = {i: cats[i] for i in sorted(cats)}
        self.coco.cats = sorted_cats
        categories = self.coco.dataset['categories']
        sorted_categories = sorted(categories, key=lambda i: i['id'])
        self.coco.dataset['categories'] = sorted_categories
        # The order of returned `cat_ids` will not
        # change with the order of the `classes`
        self.cat_ids = self.coco.get_cat_ids(
            cat_names=self.metainfo['classes'])
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)

        img_ids = self.coco.get_img_ids()
        data_list = []
        total_ann_ids = []
        for img_id in img_ids:
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id

            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
            raw_ann_info = self.coco.load_anns(ann_ids)
            total_ann_ids.extend(ann_ids)

            # keep only '<parent_dir>/<basename>' of the file name so it can
            # be matched against the entries in `v3det_ignore_list`
            file_name = osp.join(
                osp.split(osp.split(raw_img_info['file_name'])[0])[-1],
                osp.split(raw_img_info['file_name'])[-1])

            if file_name in v3det_ignore_list:
                continue

            parsed_data_info = self.parse_data_info({
                'raw_ann_info': raw_ann_info,
                'raw_img_info': raw_img_info
            })
            data_list.append(parsed_data_info)
        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(total_ann_ids), \
                f"Annotation ids in '{self.ann_file}' are not unique!"

        del self.coco

        return data_list


@DATASETS.register_module()
class YOLOv5V3DetDataset(BatchShapePolicyDataset, V3DetDataset):
    """Dataset for YOLOv5-style V3Det.

    We only add the `BatchShapePolicy` function compared with V3DetDataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details.
    """
    pass
```
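The `load_data_list` override reduces each `file_name` to its last two path components before consulting `v3det_ignore_list`. A standalone sketch of just that expression, with an example input path:

```python
import os.path as osp

raw = 'images/a00013820/26_275_28143226914_ff3a247c53_c.jpg'  # example input
# keep only '<parent_dir>/<basename>', matching the ignore-list entries
file_name = osp.join(
    osp.split(osp.split(raw)[0])[-1],  # immediate parent directory
    osp.split(raw)[-1])                # file basename
print(file_name)  # a00013820/26_275_28143226914_ff3a247c53_c.jpg
```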
yolo_world/engine/__init__.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
from .optimizers import *  # noqa
```
yolo_world/engine/optimizers/__init__.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
from .yolow_v5_optim_constructor import YOLOWv5OptimizerConstructor

__all__ = ['YOLOWv5OptimizerConstructor']
```
yolo_world/engine/optimizers/yolow_v5_optim_constructor.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
import logging
from typing import List, Optional, Union

import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm

from mmengine.dist import get_world_size
from mmengine.logging import print_log
from mmengine.optim import OptimWrapper, DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                             OPTIMIZERS)


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class YOLOWv5OptimizerConstructor(DefaultOptimWrapperConstructor):
    """YOLO World v5 constructor for optimizers."""

    def __init__(self,
                 optim_wrapper_cfg: dict,
                 paramwise_cfg: Optional[dict] = None) -> None:
        super().__init__(optim_wrapper_cfg, paramwise_cfg)
        self.base_total_batch_size = self.paramwise_cfg.pop(
            'base_total_batch_size', 64)

    def add_params(self,
                   params: List[dict],
                   module: nn.Module,
                   prefix: str = '',
                   is_dcn_module: Optional[Union[int, float]] = None) -> None:
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control conv_offset layer's learning rate. Defaults to None.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # sort alphabetically first, then by descending key length
        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)

        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None)
        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
        dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None)
        flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
        dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult',
                                                    None)

        # special rules for norm layers and depth-wise conv layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
        is_dwconv = (isinstance(module, torch.nn.Conv2d)
                     and module.in_channels == module.groups)

        for name, param in module.named_parameters(recurse=False):
            param_group = {'params': [param]}
            if bypass_duplicate and self._is_in(param_group, params):
                print_log(
                    f'{prefix} is duplicate. It is skipped since '
                    f'bypass_duplicate={bypass_duplicate}',
                    logger='current',
                    level=logging.WARNING)
                continue
            if not param.requires_grad:
                params.append(param_group)
                continue

            # if the parameter matches one of the custom keys,
            # ignore other rules
            for key in sorted_keys:
                if key in f'{prefix}.{name}':
                    lr_mult = custom_keys[key].get('lr_mult', 1.)
                    param_group['lr'] = self.base_lr * lr_mult
                    if self.base_wd is not None:
                        decay_mult = custom_keys[key].get('decay_mult', 1.)
                        param_group['weight_decay'] = \
                            self.base_wd * decay_mult
                    # add custom settings to param_group
                    for k, v in custom_keys[key].items():
                        param_group[k] = v
                    break

            # NOTE: the behaviour is different from MMDetection:
            # bias_lr_mult affects all bias parameters
            # except norm.bias and dcn.conv_offset.bias
            if name == 'bias' and not (
                    is_norm or is_dcn_module) and bias_lr_mult is not None:
                param_group['lr'] = self.base_lr * bias_lr_mult

            if (prefix.find('conv_offset') != -1 and is_dcn_module
                    and dcn_offset_lr_mult is not None
                    and isinstance(module, torch.nn.Conv2d)):
                # deal with both dcn_offset's bias & weight
                param_group['lr'] = self.base_lr * dcn_offset_lr_mult

            # apply weight decay policies
            if self.base_wd is not None:
                # norm decay
                if is_norm and norm_decay_mult is not None:
                    param_group['weight_decay'] = \
                        self.base_wd * norm_decay_mult
                # bias lr and decay
                elif (name == 'bias' and not is_dcn_module
                      and bias_decay_mult is not None):
                    param_group['weight_decay'] = \
                        self.base_wd * bias_decay_mult
                # depth-wise conv
                elif is_dwconv and dwconv_decay_mult is not None:
                    param_group['weight_decay'] = \
                        self.base_wd * dwconv_decay_mult
                # flatten parameters except dcn offset
                elif (param.ndim == 1 and not is_dcn_module
                      and flat_decay_mult is not None):
                    param_group['weight_decay'] = \
                        self.base_wd * flat_decay_mult
            params.append(param_group)
            for key, value in param_group.items():
                if key == 'params':
                    continue
                full_name = f'{prefix}.{name}' if prefix else name
                print_log(
                    f'paramwise_options -- {full_name}:{key}={value}',
                    logger='current')

        if mmcv_full_available():
            from mmcv.ops import DeformConv2d, ModulatedDeformConv2d
            is_dcn_module = isinstance(module,
                                       (DeformConv2d, ModulatedDeformConv2d))
        else:
            is_dcn_module = False
        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                is_dcn_module=is_dcn_module)

    def __call__(self, model: nn.Module) -> OptimWrapper:
        if hasattr(model, 'module'):
            model = model.module

        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()

        # follow the original yolov5 implementation
        if 'batch_size_per_gpu' in optimizer_cfg:
            batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
            # No scaling if total_batch_size is less than
            # base_total_batch_size, otherwise linear scaling.
            total_batch_size = get_world_size() * batch_size_per_gpu
            accumulate = max(
                round(self.base_total_batch_size / total_batch_size), 1)
            scale_factor = total_batch_size * \
                accumulate / self.base_total_batch_size

            if scale_factor != 1:
                weight_decay = optimizer_cfg.get('weight_decay', 0)
                weight_decay *= scale_factor
                optimizer_cfg['weight_decay'] = weight_decay
                print_log(f'Scaled weight_decay to {weight_decay}', 'current')

        # if no paramwise option is specified, just use the global setting
        if not self.paramwise_cfg:
            optimizer_cfg['params'] = model.parameters()
            optimizer = OPTIMIZERS.build(optimizer_cfg)
        else:
            # set param-wise lr and weight decay recursively
            params: List = []
            self.add_params(params, model)
            optimizer_cfg['params'] = params
            optimizer = OPTIMIZERS.build(optimizer_cfg)
        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper
```
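A sketch of a config that would exercise this constructor. Only the key names mirror what the code reads (`base_total_batch_size` is popped from `paramwise_cfg`, `batch_size_per_gpu` from the optimizer config); the numeric values are illustrative, not from this commit:

```python
# Illustrative mmengine-style optim_wrapper config; values are examples only.
optim_wrapper = dict(
    constructor='YOLOWv5OptimizerConstructor',
    optimizer=dict(
        type='SGD',
        lr=0.01,
        momentum=0.937,
        weight_decay=0.0005,
        batch_size_per_gpu=16),      # popped and used for linear scaling
    paramwise_cfg=dict(
        base_total_batch_size=64,    # popped in __init__
        norm_decay_mult=0.0,         # no weight decay on norm layers
        custom_keys={'backbone.text_model': dict(lr_mult=0.01)}))
```

With 4 GPUs at 16 images each, `total_batch_size` equals the base 64, so `scale_factor` is 1 and the weight decay is untouched; at 2 GPUs x 16 the constructor compensates with `accumulate = 2`, which again leaves the decay unscaled.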
yolo_world/models/__init__.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
from .backbones import *  # noqa
from .layers import *  # noqa
from .detectors import *  # noqa
from .losses import *  # noqa
from .data_preprocessors import *  # noqa
from .dense_heads import *  # noqa
from .necks import *  # noqa
from .assigner import *  # noqa
```
yolo_world/models/assigner/__init__.py · new file, mode 100644

```python
from .task_aligned_assigner import YOLOWorldSegAssigner

__all__ = ['YOLOWorldSegAssigner']
```
yolo_world/models/assigner/task_aligned_assigner.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
import torch
from torch import Tensor
from mmyolo.registry import TASK_UTILS
from mmyolo.models.task_modules.assigners import BatchTaskAlignedAssigner
from mmyolo.models.task_modules.assigners.utils import select_highest_overlaps


@TASK_UTILS.register_module()
class YOLOWorldSegAssigner(BatchTaskAlignedAssigner):

    def __init__(self,
                 num_classes: int,
                 topk: int = 13,
                 alpha: float = 1,
                 beta: float = 6,
                 eps: float = 1e-7,
                 use_ciou: bool = False):
        super().__init__(num_classes, topk, alpha, beta, eps, use_ciou)

    @torch.no_grad()
    def forward(
        self,
        pred_bboxes: Tensor,
        pred_scores: Tensor,
        priors: Tensor,
        gt_labels: Tensor,
        gt_bboxes: Tensor,
        pad_bbox_flag: Tensor,
    ) -> dict:
        """Assign gt to bboxes.

        The assignment is done in the following steps:

        1. compute alignment metric between all bboxes (bboxes of all pyramid
           levels) and gt
        2. select top-k bboxes as candidates for each gt
        3. limit the positive sample's center in gt (because the anchor-free
           detector can only predict positive distance)

        Args:
            pred_bboxes (Tensor): Predicted bboxes,
                shape(batch_size, num_priors, 4)
            pred_scores (Tensor): Scores of predicted bboxes,
                shape(batch_size, num_priors, num_classes)
            priors (Tensor): Model priors, shape (num_priors, 4)
            gt_labels (Tensor): Ground truth labels,
                shape(batch_size, num_gt, 1)
            gt_bboxes (Tensor): Ground truth bboxes,
                shape(batch_size, num_gt, 4)
            pad_bbox_flag (Tensor): Ground truth bbox mask,
                1 means bbox, 0 means no bbox,
                shape(batch_size, num_gt, 1)
        Returns:
            assigned_result (dict): Assigned result:
                assigned_labels (Tensor): Assigned labels,
                    shape(batch_size, num_priors)
                assigned_bboxes (Tensor): Assigned boxes,
                    shape(batch_size, num_priors, 4)
                assigned_scores (Tensor): Assigned scores,
                    shape(batch_size, num_priors, num_classes)
                fg_mask_pre_prior (Tensor): Force ground truth matching mask,
                    shape(batch_size, num_priors)
        """
        # (num_priors, 4) -> (num_priors, 2)
        priors = priors[:, :2]

        batch_size = pred_scores.size(0)
        num_gt = gt_bboxes.size(1)

        assigned_result = {
            'assigned_labels':
            gt_bboxes.new_full(pred_scores[..., 0].shape, self.num_classes),
            'assigned_bboxes':
            gt_bboxes.new_full(pred_bboxes.shape, 0),
            'assigned_scores':
            gt_bboxes.new_full(pred_scores.shape, 0),
            'fg_mask_pre_prior':
            gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
        }

        if num_gt == 0:
            return assigned_result

        pos_mask, alignment_metrics, overlaps = self.get_pos_mask(
            pred_bboxes, pred_scores, priors, gt_labels, gt_bboxes,
            pad_bbox_flag, batch_size, num_gt)

        (assigned_gt_idxs, fg_mask_pre_prior,
         pos_mask) = select_highest_overlaps(pos_mask, overlaps, num_gt)

        # assigned target
        assigned_labels, assigned_bboxes, assigned_scores = self.get_targets(
            gt_labels, gt_bboxes, assigned_gt_idxs, fg_mask_pre_prior,
            batch_size, num_gt)

        # normalize
        alignment_metrics *= pos_mask
        pos_align_metrics = alignment_metrics.max(axis=-1, keepdim=True)[0]
        pos_overlaps = (overlaps * pos_mask).max(axis=-1, keepdim=True)[0]
        norm_align_metric = (
            alignment_metrics * pos_overlaps /
            (pos_align_metrics + self.eps)).max(-2)[0].unsqueeze(-1)
        assigned_scores = assigned_scores * norm_align_metric

        assigned_result['assigned_labels'] = assigned_labels
        assigned_result['assigned_bboxes'] = assigned_bboxes
        assigned_result['assigned_scores'] = assigned_scores
        assigned_result['fg_mask_pre_prior'] = fg_mask_pre_prior.bool()
        assigned_result['assigned_gt_idxs'] = assigned_gt_idxs
        return assigned_result
```
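A shape-level sketch of calling the assigner with random but valid boxes, assuming mmyolo is installed and this package is importable as `yolo_world`; all sizes are arbitrary:

```python
import torch

from yolo_world.models.assigner import YOLOWorldSegAssigner

# Arbitrary sizes for illustration only.
bs, num_priors, num_classes, num_gt = 2, 8400, 4, 3
assigner = YOLOWorldSegAssigner(num_classes=num_classes)

# Build valid xyxy boxes (x2 > x1, y2 > y1) from random corners and sizes.
xy = torch.rand(bs, num_gt, 2) * 320
wh = torch.rand(bs, num_gt, 2) * 160 + 8
gt_bboxes = torch.cat([xy, xy + wh], dim=-1)

pxy = torch.rand(bs, num_priors, 2) * 600
pwh = torch.rand(bs, num_priors, 2) * 40 + 4
pred_bboxes = torch.cat([pxy, pxy + pwh], dim=-1)   # decoded xyxy boxes

result = assigner(
    pred_bboxes=pred_bboxes,
    pred_scores=torch.rand(bs, num_priors, num_classes),
    priors=torch.rand(num_priors, 4) * 640,          # only [:, :2] is used
    gt_labels=torch.zeros(bs, num_gt, 1, dtype=torch.long),
    gt_bboxes=gt_bboxes,
    pad_bbox_flag=torch.ones(bs, num_gt, 1))

print(result['assigned_scores'].shape)    # torch.Size([2, 8400, 4])
print(result['fg_mask_pre_prior'].shape)  # torch.Size([2, 8400])
```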
yolo_world/models/backbones/__init__.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
# YOLO Multi-Modal Backbone (Vision Language)
# Vision: YOLOv8 CSPDarknet
# Language: CLIP Text Encoder (12-layer transformer)
from .mm_backbone import (MultiModalYOLOBackbone, HuggingVisionBackbone,
                          HuggingCLIPLanguageBackbone, PseudoLanguageBackbone)

__all__ = [
    'MultiModalYOLOBackbone', 'HuggingVisionBackbone',
    'HuggingCLIPLanguageBackbone', 'PseudoLanguageBackbone'
]
```
yolo_world/models/backbones/mm_backbone.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
import itertools
from typing import List, Sequence, Tuple

import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmengine.model import BaseModule
from mmyolo.registry import MODELS
from mmdet.utils import OptMultiConfig, ConfigType
from transformers import AutoTokenizer, AutoModel, CLIPTextConfig
from transformers import CLIPTextModelWithProjection as CLIPTP


@MODELS.register_module()
class HuggingVisionBackbone(BaseModule):

    def __init__(self,
                 model_name: str,
                 out_indices: Sequence[int] = (0, 1, 2, 3),
                 norm_eval: bool = True,
                 frozen_modules: Sequence[str] = (),
                 init_cfg: OptMultiConfig = None) -> None:

        super().__init__(init_cfg=init_cfg)

        self.norm_eval = norm_eval
        self.frozen_modules = frozen_modules
        self.out_indices = out_indices
        self.model = AutoModel.from_pretrained(model_name)

        self._freeze_modules()

    def forward(self, image: Tensor) -> Tuple[Tensor]:
        encoded_dict = self.model(pixel_values=image,
                                  output_hidden_states=True)
        hidden_states = encoded_dict.hidden_states
        img_feats = encoded_dict.get('reshaped_hidden_states', hidden_states)
        img_feats = [img_feats[i] for i in self.out_indices]
        return tuple(img_feats)

    def _freeze_modules(self):
        for name, module in self.model.named_modules():
            for frozen_name in self.frozen_modules:
                if name.startswith(frozen_name):
                    module.eval()
                    for param in module.parameters():
                        param.requires_grad = False
                    break

    def train(self, mode=True):
        super().train(mode)
        self._freeze_modules()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()


@MODELS.register_module()
class HuggingCLIPLanguageBackbone(BaseModule):

    def __init__(self,
                 model_name: str,
                 frozen_modules: Sequence[str] = (),
                 dropout: float = 0.0,
                 training_use_cache: bool = False,
                 init_cfg: OptMultiConfig = None) -> None:

        super().__init__(init_cfg=init_cfg)

        self.frozen_modules = frozen_modules
        self.training_use_cache = training_use_cache
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        clip_config = CLIPTextConfig.from_pretrained(
            model_name, attention_dropout=dropout)
        self.model = CLIPTP.from_pretrained(model_name, config=clip_config)

        self._freeze_modules()

    def forward_tokenizer(self, texts):
        if not hasattr(self, 'text'):
            text = list(itertools.chain(*texts))
            text = self.tokenizer(text=text,
                                  return_tensors='pt',
                                  padding=True)
            self.text = text.to(device=self.model.device)
        return self.text

    def forward(self, text: List[List[str]]) -> Tensor:
        num_per_batch = [len(t) for t in text]
        assert max(num_per_batch) == min(num_per_batch), (
            'number of sequences not equal in batch')
        text = list(itertools.chain(*text))
        text = self.tokenizer(text=text, return_tensors='pt', padding=True)
        text = text.to(device=self.model.device)
        txt_outputs = self.model(**text)
        txt_feats = txt_outputs.text_embeds
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        txt_feats = txt_feats.reshape(-1, num_per_batch[0],
                                      txt_feats.shape[-1])
        return txt_feats

    def _freeze_modules(self):
        if len(self.frozen_modules) == 0:
            # not freeze
            return
        if self.frozen_modules[0] == "all":
            self.model.eval()
            for _, module in self.model.named_modules():
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False
            return
        for name, module in self.model.named_modules():
            for frozen_name in self.frozen_modules:
                if name.startswith(frozen_name):
                    module.eval()
                    for param in module.parameters():
                        param.requires_grad = False
                    break

    def train(self, mode=True):
        super().train(mode)
        self._freeze_modules()


@MODELS.register_module()
class PseudoLanguageBackbone(BaseModule):
    """Pseudo Language Backbone

    Args:
        text_embed_path (str): path to the text embedding file
    """

    def __init__(self,
                 text_embed_path: str = "",
                 test_embed_path: str = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg)
        # {text: embed}
        self.text_embed = torch.load(text_embed_path, map_location='cpu')
        if test_embed_path is None:
            self.test_embed = self.text_embed
        else:
            self.test_embed = torch.load(test_embed_path)
        self.register_buffer("buff", torch.zeros([1, ]))

    def forward_cache(self, text: List[List[str]]) -> Tensor:
        if not hasattr(self, "cache"):
            self.cache = self.forward_text(text)
        return self.cache

    def forward(self, text: List[List[str]]) -> Tensor:
        if self.training:
            return self.forward_text(text)
        else:
            return self.forward_cache(text)

    def forward_text(self, text: List[List[str]]) -> Tensor:
        num_per_batch = [len(t) for t in text]
        assert max(num_per_batch) == min(num_per_batch), (
            'number of sequences not equal in batch')
        text = list(itertools.chain(*text))
        if self.training:
            text_embed_dict = self.text_embed
        else:
            text_embed_dict = self.test_embed
        text_embeds = torch.stack(
            [text_embed_dict[x.split("/")[0]] for x in text])
        # requires no grad and force to float
        text_embeds = text_embeds.to(
            self.buff.device).requires_grad_(False).float()
        text_embeds = text_embeds.reshape(-1, num_per_batch[0],
                                          text_embeds.shape[-1])
        return text_embeds


@MODELS.register_module()
class MultiModalYOLOBackbone(BaseModule):

    def __init__(self,
                 image_model: ConfigType,
                 text_model: ConfigType,
                 frozen_stages: int = -1,
                 with_text_model: bool = True,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)
        self.with_text_model = with_text_model
        self.image_model = MODELS.build(image_model)
        if self.with_text_model:
            self.text_model = MODELS.build(text_model)
        else:
            self.text_model = None
        self.frozen_stages = frozen_stages
        self._freeze_stages()

    def _freeze_stages(self):
        """Freeze the parameters of the specified stage so that they are no
        longer updated."""
        if self.frozen_stages >= 0:
            for i in range(self.frozen_stages + 1):
                m = getattr(self.image_model, self.image_model.layers[i])
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode: bool = True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super().train(mode)
        self._freeze_stages()

    def forward(self, image: Tensor,
                text: List[List[str]]) -> Tuple[Tuple[Tensor], Tensor]:
        img_feats = self.image_model(image)
        if self.with_text_model:
            txt_feats = self.text_model(text)
            return img_feats, txt_feats
        else:
            return img_feats, None

    def forward_text(self, text: List[List[str]]) -> Tensor:
        assert self.with_text_model, "forward_text() requires a text model"
        txt_feats = self.text_model(text)
        return txt_feats

    def forward_image(self, image: Tensor) -> Tuple[Tensor]:
        return self.image_model(image)
```
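A minimal sketch of encoding per-image prompt lists with the CLIP text branch. The checkpoint id below is the stock Hugging Face CLIP ViT-B/32 identifier, an assumption rather than something pinned by this commit, and the example assumes `transformers` can download it:

```python
# Sketch only: assumes the package is importable as `yolo_world` and that
# the (assumed) checkpoint 'openai/clip-vit-base-patch32' is reachable.
from yolo_world.models.backbones.mm_backbone import HuggingCLIPLanguageBackbone

text_model = HuggingCLIPLanguageBackbone(
    model_name='openai/clip-vit-base-patch32',  # assumed checkpoint id
    frozen_modules=('all', ))                   # freeze the whole encoder

# One prompt list per image; forward() asserts the lists are equal length.
texts = [['person', 'dog'], ['car', 'traffic light']]
txt_feats = text_model(texts)
print(txt_feats.shape)  # (batch=2, num_prompts=2, embed_dim=512 for ViT-B/32)
```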
yolo_world/models/data_preprocessors/__init__.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
from .data_preprocessor import YOLOWDetDataPreprocessor

__all__ = ['YOLOWDetDataPreprocessor']
```
yolo_world/models/data_preprocessors/data_preprocessor.py · new file, mode 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import torch
from mmdet.models.data_preprocessors import DetDataPreprocessor
from mmengine.structures import BaseDataElement
from mmyolo.registry import MODELS

CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes,
                 str, None]


@MODELS.register_module()
class YOLOWDetDataPreprocessor(DetDataPreprocessor):
    """Rewrite collate_fn to get faster training speed.

    Note: It must be used together with `mmyolo.datasets.utils.yolow_collate`
    """

    def __init__(self, *args, non_blocking: Optional[bool] = True, **kwargs):
        super().__init__(*args, non_blocking=non_blocking, **kwargs)

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``DetDataPreprocessor``.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict: Data in the same format as the model input.
        """
        if not training:
            return super().forward(data, training)

        data = self.cast_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        assert isinstance(data['data_samples'], dict)

        # TODO: Supports multi-scale training
        if self._channel_conversion and inputs.shape[1] == 3:
            inputs = inputs[:, [2, 1, 0], ...]
        if self._enable_normalize:
            inputs = (inputs - self.mean) / self.std

        if self.batch_augments is not None:
            for batch_aug in self.batch_augments:
                inputs, data_samples = batch_aug(inputs, data_samples)

        img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
        data_samples_output = {
            'bboxes_labels': data_samples['bboxes_labels'],
            'texts': data_samples['texts'],
            'img_metas': img_metas
        }
        if 'masks' in data_samples:
            data_samples_output['masks'] = data_samples['masks']
        if 'is_detection' in data_samples:
            data_samples_output['is_detection'] = \
                data_samples['is_detection']

        return {'inputs': inputs, 'data_samples': data_samples_output}
```
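The preprocessor assumes the collate function has already merged `data_samples` into a plain dict rather than a list of per-image samples (see the note in the class docstring). A sketch of the batch layout `forward` expects when `training=True`; the tensors and text values are dummies:

```python
import torch

# Dummy batch in the layout YOLOWDetDataPreprocessor.forward expects
# during training (shapes and contents illustrative only).
batch = {
    'inputs': torch.randint(0, 256, (2, 3, 640, 640), dtype=torch.uint8),
    'data_samples': {
        # one row per gt box: (batch_idx, class_id, x1, y1, x2, y2)
        'bboxes_labels': torch.zeros(5, 6),
        # per-image prompt lists, in the format the grounding datasets emit
        'texts': [[['person'], ['dog']], [['person'], ['dog']]],
    },
}
# forward() would BGR->RGB-flip, normalize, apply batch augments, and
# attach 'img_metas' with the batch input shape.
```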
yolo_world/models/dense_heads/__init__.py · new file, mode 100644

```python
# Copyright (c) Tencent Inc. All rights reserved.
from .yolo_world_head import (YOLOWorldHead, YOLOWorldHeadModule,
                              RepYOLOWorldHeadModule)
from .yolo_world_seg_head import YOLOWorldSegHead, YOLOWorldSegHeadModule

__all__ = [
    'YOLOWorldHead', 'YOLOWorldHeadModule', 'YOLOWorldSegHead',
    'YOLOWorldSegHeadModule', 'RepYOLOWorldHeadModule'
]
```
yolo_world/models/dense_heads/yolo_world_head.py
0 → 100644
View file @
e9cee049
# Copyright (c) Tencent Inc. All rights reserved.
import
math
import
copy
from
typing
import
List
,
Optional
,
Tuple
,
Union
,
Sequence
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
numpy
as
np
from
mmcv.cnn
import
ConvModule
from
mmengine.config
import
ConfigDict
from
mmengine.model
import
BaseModule
from
torch
import
Tensor
from
torch.nn.modules.batchnorm
import
_BatchNorm
from
mmengine.dist
import
get_dist_info
from
mmengine.structures
import
InstanceData
from
mmdet.structures
import
SampleList
from
mmdet.utils
import
OptConfigType
,
InstanceList
,
OptInstanceList
from
mmdet.models.utils
import
(
multi_apply
,
unpack_gt_instances
,
filter_scores_and_topk
)
from
mmyolo.registry
import
MODELS
from
mmyolo.models.dense_heads
import
YOLOv8HeadModule
,
YOLOv8Head
from
mmyolo.models.utils
import
gt_instances_preprocess
from
mmcv.cnn.bricks
import
build_norm_layer
@
MODELS
.
register_module
()
class
ContrastiveHead
(
BaseModule
):
"""Contrastive Head for YOLO-World
compute the region-text scores according to the
similarity between image and text features
Args:
embed_dims (int): embed dim of text and image features
"""
def
__init__
(
self
,
embed_dims
:
int
,
init_cfg
:
OptConfigType
=
None
,
use_einsum
:
bool
=
True
)
->
None
:
super
().
__init__
(
init_cfg
=
init_cfg
)
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
([]))
self
.
logit_scale
=
nn
.
Parameter
(
torch
.
ones
([])
*
np
.
log
(
1
/
0.07
))
self
.
use_einsum
=
use_einsum
def
forward
(
self
,
x
:
Tensor
,
w
:
Tensor
)
->
Tensor
:
"""Forward function of contrastive learning."""
x
=
F
.
normalize
(
x
,
dim
=
1
,
p
=
2
)
w
=
F
.
normalize
(
w
,
dim
=-
1
,
p
=
2
)
if
self
.
use_einsum
:
x
=
torch
.
einsum
(
'bchw,bkc->bkhw'
,
x
,
w
)
else
:
batch
,
channel
,
height
,
width
=
x
.
shape
_
,
k
,
_
=
w
.
shape
x
=
x
.
permute
(
0
,
2
,
3
,
1
)
# bchw->bhwc
x
=
x
.
reshape
(
batch
,
-
1
,
channel
)
# bhwc->b(hw)c
w
=
w
.
permute
(
0
,
2
,
1
)
# bkc->bck
x
=
torch
.
matmul
(
x
,
w
)
x
=
x
.
reshape
(
batch
,
height
,
width
,
k
)
x
=
x
.
permute
(
0
,
3
,
1
,
2
)
x
=
x
*
self
.
logit_scale
.
exp
()
+
self
.
bias
return
x
@
MODELS
.
register_module
()
class
BNContrastiveHead
(
BaseModule
):
""" Batch Norm Contrastive Head for YOLO-World
using batch norm instead of l2-normalization
Args:
embed_dims (int): embed dim of text and image features
norm_cfg (dict): normalization params
"""
def
__init__
(
self
,
embed_dims
:
int
,
norm_cfg
:
ConfigDict
,
init_cfg
:
OptConfigType
=
None
,
use_einsum
:
bool
=
True
)
->
None
:
super
().
__init__
(
init_cfg
=
init_cfg
)
self
.
norm
=
build_norm_layer
(
norm_cfg
,
embed_dims
)[
1
]
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
([]))
# use -1.0 is more stable
self
.
logit_scale
=
nn
.
Parameter
(
-
1.0
*
torch
.
ones
([]))
self
.
use_einsum
=
use_einsum
def
forward
(
self
,
x
:
Tensor
,
w
:
Tensor
)
->
Tensor
:
"""Forward function of contrastive learning."""
x
=
self
.
norm
(
x
)
w
=
F
.
normalize
(
w
,
dim
=-
1
,
p
=
2
)
if
self
.
use_einsum
:
x
=
torch
.
einsum
(
'bchw,bkc->bkhw'
,
x
,
w
)
else
:
batch
,
channel
,
height
,
width
=
x
.
shape
_
,
k
,
_
=
w
.
shape
x
=
x
.
permute
(
0
,
2
,
3
,
1
)
# bchw->bhwc
x
=
x
.
reshape
(
batch
,
-
1
,
channel
)
# bhwc->b(hw)c
w
=
w
.
permute
(
0
,
2
,
1
)
# bkc->bck
x
=
torch
.
matmul
(
x
,
w
)
x
=
x
.
reshape
(
batch
,
height
,
width
,
k
)
x
=
x
.
permute
(
0
,
3
,
1
,
2
)
x
=
x
*
self
.
logit_scale
.
exp
()
+
self
.
bias
return
x
@
MODELS
.
register_module
()
class
RepBNContrastiveHead
(
BaseModule
):
""" Batch Norm Contrastive Head for YOLO-World
using batch norm instead of l2-normalization
Args:
embed_dims (int): embed dim of text and image features
norm_cfg (dict): normalization params
"""
def
__init__
(
self
,
embed_dims
:
int
,
num_guide_embeds
:
int
,
norm_cfg
:
ConfigDict
,
init_cfg
:
OptConfigType
=
None
)
->
None
:
super
().
__init__
(
init_cfg
=
init_cfg
)
self
.
norm
=
build_norm_layer
(
norm_cfg
,
embed_dims
)[
1
]
self
.
conv
=
nn
.
Conv2d
(
embed_dims
,
num_guide_embeds
,
kernel_size
=
1
)
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
"""Forward function of contrastive learning."""
x
=
self
.
norm
(
x
)
x
=
self
.
conv
(
x
)
return
x
@
MODELS
.
register_module
()
class
YOLOWorldHeadModule
(
YOLOv8HeadModule
):
"""Head Module for YOLO-World
Args:
embed_dims (int): embed dim for text feautures and image features
use_bn_head (bool): use batch normalization head
"""
def
__init__
(
self
,
*
args
,
embed_dims
:
int
,
use_bn_head
:
bool
=
False
,
use_einsum
:
bool
=
True
,
freeze_all
:
bool
=
False
,
**
kwargs
)
->
None
:
self
.
embed_dims
=
embed_dims
self
.
use_bn_head
=
use_bn_head
self
.
use_einsum
=
use_einsum
self
.
freeze_all
=
freeze_all
super
().
__init__
(
*
args
,
**
kwargs
)
def
init_weights
(
self
,
prior_prob
=
0.01
):
"""Initialize the weight and bias of PPYOLOE head."""
super
().
init_weights
()
for
cls_pred
,
cls_contrast
,
stride
in
zip
(
self
.
cls_preds
,
self
.
cls_contrasts
,
self
.
featmap_strides
):
cls_pred
[
-
1
].
bias
.
data
[:]
=
0.0
# reset bias
if
hasattr
(
cls_contrast
,
'bias'
):
nn
.
init
.
constant_
(
cls_contrast
.
bias
.
data
,
math
.
log
(
5
/
self
.
num_classes
/
(
640
/
stride
)
**
2
))
def
_init_layers
(
self
)
->
None
:
"""initialize conv layers in YOLOv8 head."""
# Init decouple head
self
.
cls_preds
=
nn
.
ModuleList
()
self
.
reg_preds
=
nn
.
ModuleList
()
self
.
cls_contrasts
=
nn
.
ModuleList
()
reg_out_channels
=
max
(
(
16
,
self
.
in_channels
[
0
]
//
4
,
self
.
reg_max
*
4
))
cls_out_channels
=
max
(
self
.
in_channels
[
0
],
self
.
num_classes
)
for
i
in
range
(
self
.
num_levels
):
self
.
reg_preds
.
append
(
nn
.
Sequential
(
ConvModule
(
in_channels
=
self
.
in_channels
[
i
],
out_channels
=
reg_out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
norm_cfg
=
self
.
norm_cfg
,
act_cfg
=
self
.
act_cfg
),
ConvModule
(
in_channels
=
reg_out_channels
,
out_channels
=
reg_out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
norm_cfg
=
self
.
norm_cfg
,
act_cfg
=
self
.
act_cfg
),
nn
.
Conv2d
(
in_channels
=
reg_out_channels
,
out_channels
=
4
*
self
.
reg_max
,
kernel_size
=
1
)))
self
.
cls_preds
.
append
(
nn
.
Sequential
(
ConvModule
(
in_channels
=
self
.
in_channels
[
i
],
out_channels
=
cls_out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
norm_cfg
=
self
.
norm_cfg
,
act_cfg
=
self
.
act_cfg
),
ConvModule
(
in_channels
=
cls_out_channels
,
out_channels
=
cls_out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
norm_cfg
=
self
.
norm_cfg
,
act_cfg
=
self
.
act_cfg
),
nn
.
Conv2d
(
in_channels
=
cls_out_channels
,
out_channels
=
self
.
embed_dims
,
kernel_size
=
1
)))
if
self
.
use_bn_head
:
self
.
cls_contrasts
.
append
(
BNContrastiveHead
(
self
.
embed_dims
,
self
.
norm_cfg
,
use_einsum
=
self
.
use_einsum
))
else
:
self
.
cls_contrasts
.
append
(
ContrastiveHead
(
self
.
embed_dims
,
use_einsum
=
self
.
use_einsum
))
proj
=
torch
.
arange
(
self
.
reg_max
,
dtype
=
torch
.
float
)
self
.
register_buffer
(
'proj'
,
proj
,
persistent
=
False
)
if
self
.
freeze_all
:
self
.
_freeze_all
()
def
_freeze_all
(
self
):
"""Freeze the model."""
for
m
in
self
.
modules
():
if
isinstance
(
m
,
_BatchNorm
):
m
.
eval
()
for
param
in
m
.
parameters
():
param
.
requires_grad
=
False
def
train
(
self
,
mode
=
True
):
super
().
train
(
mode
)
if
self
.
freeze_all
:
self
.
_freeze_all
()
def
forward
(
self
,
img_feats
:
Tuple
[
Tensor
],
txt_feats
:
Tensor
)
->
Tuple
[
List
]:
"""Forward features from the upstream network."""
assert
len
(
img_feats
)
==
self
.
num_levels
txt_feats
=
[
txt_feats
for
_
in
range
(
self
.
num_levels
)]
return
multi_apply
(
self
.
forward_single
,
img_feats
,
txt_feats
,
self
.
cls_preds
,
self
.
reg_preds
,
self
.
cls_contrasts
)
def
forward_single
(
self
,
img_feat
:
Tensor
,
txt_feat
:
Tensor
,
cls_pred
:
nn
.
ModuleList
,
reg_pred
:
nn
.
ModuleList
,
cls_contrast
:
nn
.
ModuleList
)
->
Tuple
:
"""Forward feature of a single scale level."""
b
,
_
,
h
,
w
=
img_feat
.
shape
cls_embed
=
cls_pred
(
img_feat
)
cls_logit
=
cls_contrast
(
cls_embed
,
txt_feat
)
bbox_dist_preds
=
reg_pred
(
img_feat
)
if
self
.
reg_max
>
1
:
bbox_dist_preds
=
bbox_dist_preds
.
reshape
(
[
-
1
,
4
,
self
.
reg_max
,
h
*
w
]).
permute
(
0
,
3
,
1
,
2
)
# TODO: The get_flops script cannot handle the situation of
# matmul, and needs to be fixed later
# bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
bbox_preds
=
bbox_dist_preds
.
softmax
(
3
).
matmul
(
self
.
proj
.
view
([
-
1
,
1
])).
squeeze
(
-
1
)
bbox_preds
=
bbox_preds
.
transpose
(
1
,
2
).
reshape
(
b
,
-
1
,
h
,
w
)
else
:
bbox_preds
=
bbox_dist_preds
if
self
.
training
:
return
cls_logit
,
bbox_preds
,
bbox_dist_preds
else
:
return
cls_logit
,
bbox_preds
@
MODELS
.
register_module
()
class
RepYOLOWorldHeadModule
(
YOLOWorldHeadModule
):
def
__init__
(
self
,
*
args
,
embed_dims
:
int
,
num_guide
:
int
,
freeze_all
:
bool
=
False
,
**
kwargs
)
->
None
:
super
().
__init__
(
*
args
,
embed_dims
=
embed_dims
,
use_bn_head
=
True
,
use_einsum
=
False
,
freeze_all
=
freeze_all
,
**
kwargs
)
# using rep head
cls_contrasts
=
[]
for
_
in
range
(
self
.
num_levels
):
cls_contrasts
.
append
(
RepBNContrastiveHead
(
embed_dims
=
embed_dims
,
num_guide_embeds
=
num_guide
,
norm_cfg
=
self
.
norm_cfg
)
)
self
.
cls_contrasts
=
nn
.
ModuleList
(
cls_contrasts
)
def
forward_single
(
self
,
img_feat
:
Tensor
,
cls_pred
:
nn
.
ModuleList
,
reg_pred
:
nn
.
ModuleList
,
cls_contrast
:
nn
.
ModuleList
)
->
Tuple
:
"""Forward features from the upstream network."""
b
,
_
,
h
,
w
=
img_feat
.
shape
cls_embed
=
cls_pred
(
img_feat
)
cls_logit
=
cls_contrast
(
cls_embed
)
bbox_dist_preds
=
reg_pred
(
img_feat
)
if
self
.
reg_max
>
1
:
bbox_dist_preds
=
bbox_dist_preds
.
reshape
(
[
-
1
,
4
,
self
.
reg_max
,
h
*
w
]).
permute
(
0
,
3
,
1
,
2
)
# TODO: The get_flops script cannot handle the situation of
# matmul, and needs to be fixed later
# bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
bbox_preds
=
bbox_dist_preds
.
softmax
(
3
).
matmul
(
self
.
proj
.
view
([
-
1
,
1
])).
squeeze
(
-
1
)
bbox_preds
=
bbox_preds
.
transpose
(
1
,
2
).
reshape
(
b
,
-
1
,
h
,
w
)
else
:
bbox_preds
=
bbox_dist_preds
if
self
.
training
:
return
cls_logit
,
bbox_preds
,
bbox_dist_preds
else
:
return
cls_logit
,
bbox_preds
def
forward
(
self
,
img_feats
:
Tuple
[
Tensor
])
->
Tuple
[
List
]:
assert
len
(
img_feats
)
==
self
.
num_levels
return
multi_apply
(
self
.
forward_single
,
img_feats
,
self
.
cls_preds
,
self
.
reg_preds
,
self
.
cls_contrasts
)
@
MODELS
.
register_module
()
class
YOLOWorldHead
(
YOLOv8Head
):
"""YOLO-World Head
"""
def
__init__
(
self
,
world_size
=-
1
,
*
args
,
**
kwargs
)
->
None
:
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
world_size
=
world_size
"""YOLO World v8 head."""
def
loss
(
self
,
img_feats
:
Tuple
[
Tensor
],
txt_feats
:
Tensor
,
batch_data_samples
:
Union
[
list
,
dict
])
->
dict
:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network."""
outs
=
self
(
img_feats
,
txt_feats
)
# Fast version
loss_inputs
=
outs
+
(
batch_data_samples
[
'bboxes_labels'
],
batch_data_samples
[
'img_metas'
])
losses
=
self
.
loss_by_feat
(
*
loss_inputs
)
return
losses
def
loss_and_predict
(
self
,
img_feats
:
Tuple
[
Tensor
],
txt_feats
:
Tensor
,
batch_data_samples
:
SampleList
,
proposal_cfg
:
Optional
[
ConfigDict
]
=
None
)
->
Tuple
[
dict
,
InstanceList
]:
"""Perform forward propagation of the head, then calculate loss and
predictions from the features and data samples.
"""
outputs
=
unpack_gt_instances
(
batch_data_samples
)
(
batch_gt_instances
,
batch_gt_instances_ignore
,
batch_img_metas
)
=
outputs
outs
=
self
(
img_feats
,
txt_feats
)
loss_inputs
=
outs
+
(
batch_gt_instances
,
batch_img_metas
,
batch_gt_instances_ignore
)
losses
=
self
.
loss_by_feat
(
*
loss_inputs
)
predictions
=
self
.
predict_by_feat
(
*
outs
,
batch_img_metas
=
batch_img_metas
,
cfg
=
proposal_cfg
)
return
losses
,
predictions
def
forward
(
self
,
img_feats
:
Tuple
[
Tensor
],
txt_feats
:
Tensor
)
->
Tuple
[
List
]:
"""Forward features from the upstream network."""
return
self
.
head_module
(
img_feats
,
txt_feats
)
def
predict
(
self
,
img_feats
:
Tuple
[
Tensor
],
txt_feats
:
Tensor
,
batch_data_samples
:
SampleList
,
rescale
:
bool
=
False
)
->
InstanceList
:
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network.
"""
batch_img_metas
=
[
data_samples
.
metainfo
for
data_samples
in
batch_data_samples
]
outs
=
self
(
img_feats
,
txt_feats
)
predictions
=
self
.
predict_by_feat
(
*
outs
,
batch_img_metas
=
batch_img_metas
,
rescale
=
rescale
)
return
predictions
def
aug_test
(
self
,
aug_batch_feats
,
aug_batch_img_metas
,
rescale
=
False
,
with_ori_nms
=
False
,
**
kwargs
):
"""Test function with test time augmentation."""
raise
NotImplementedError
(
'aug_test is not implemented yet.'
)
def
loss_by_feat
(
self
,
cls_scores
:
Sequence
[
Tensor
],
bbox_preds
:
Sequence
[
Tensor
],
bbox_dist_preds
:
Sequence
[
Tensor
],
batch_gt_instances
:
Sequence
[
InstanceData
],
batch_img_metas
:
Sequence
[
dict
],
batch_gt_instances_ignore
:
OptInstanceList
=
None
)
->
dict
:
"""Calculate the loss based on the features extracted by the detection
head.
Args:
cls_scores (Sequence[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is
num_priors * num_classes.
bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_priors * 4.
bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
each scale level with shape (bs, reg_max + 1, H*W, 4).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of losses.
"""
num_imgs
=
len
(
batch_img_metas
)
current_featmap_sizes
=
[
cls_score
.
shape
[
2
:]
for
cls_score
in
cls_scores
]
# If the shape does not equal, generate new one
if
current_featmap_sizes
!=
self
.
featmap_sizes_train
:
self
.
featmap_sizes_train
=
current_featmap_sizes
mlvl_priors_with_stride
=
self
.
prior_generator
.
grid_priors
(
self
.
featmap_sizes_train
,
dtype
=
cls_scores
[
0
].
dtype
,
device
=
cls_scores
[
0
].
device
,
with_stride
=
True
)
self
.
num_level_priors
=
[
len
(
n
)
for
n
in
mlvl_priors_with_stride
]
self
.
flatten_priors_train
=
torch
.
cat
(
mlvl_priors_with_stride
,
dim
=
0
)
self
.
stride_tensor
=
self
.
flatten_priors_train
[...,
[
2
]]
# gt info
gt_info
=
gt_instances_preprocess
(
batch_gt_instances
,
num_imgs
)
gt_labels
=
gt_info
[:,
:,
:
1
]
gt_bboxes
=
gt_info
[:,
:,
1
:]
# xyxy
pad_bbox_flag
=
(
gt_bboxes
.
sum
(
-
1
,
keepdim
=
True
)
>
0
).
float
()
# pred info
flatten_cls_preds
=
[
cls_pred
.
permute
(
0
,
2
,
3
,
1
).
reshape
(
num_imgs
,
-
1
,
self
.
num_classes
)
for
cls_pred
in
cls_scores
]
flatten_pred_bboxes
=
[
bbox_pred
.
permute
(
0
,
2
,
3
,
1
).
reshape
(
num_imgs
,
-
1
,
4
)
for
bbox_pred
in
bbox_preds
]
# (bs, n, 4 * reg_max)
flatten_pred_dists
=
[
bbox_pred_org
.
reshape
(
num_imgs
,
-
1
,
self
.
head_module
.
reg_max
*
4
)
for
bbox_pred_org
in
bbox_dist_preds
]
flatten_dist_preds
=
torch
.
cat
(
flatten_pred_dists
,
dim
=
1
)
flatten_cls_preds
=
torch
.
cat
(
flatten_cls_preds
,
dim
=
1
)
flatten_pred_bboxes
=
torch
.
cat
(
flatten_pred_bboxes
,
dim
=
1
)
flatten_pred_bboxes
=
self
.
bbox_coder
.
decode
(
self
.
flatten_priors_train
[...,
:
2
],
flatten_pred_bboxes
,
self
.
stride_tensor
[...,
0
])
assigned_result
=
self
.
assigner
(
(
flatten_pred_bboxes
.
detach
()).
type
(
gt_bboxes
.
dtype
),
flatten_cls_preds
.
detach
().
sigmoid
(),
self
.
flatten_priors_train
,
gt_labels
,
gt_bboxes
,
pad_bbox_flag
)
assigned_bboxes
=
assigned_result
[
'assigned_bboxes'
]
assigned_scores
=
assigned_result
[
'assigned_scores'
]
fg_mask_pre_prior
=
assigned_result
[
'fg_mask_pre_prior'
]
assigned_scores_sum
=
assigned_scores
.
sum
().
clamp
(
min
=
1
)
loss_cls
=
self
.
loss_cls
(
flatten_cls_preds
,
assigned_scores
).
sum
()
loss_cls
/=
assigned_scores_sum
# rescale bbox
assigned_bboxes
/=
self
.
stride_tensor
flatten_pred_bboxes
/=
self
.
stride_tensor
# select positive samples mask
num_pos
=
fg_mask_pre_prior
.
sum
()
if
num_pos
>
0
:
# when num_pos > 0, assigned_scores_sum will >0, so the loss_bbox
# will not report an error
# iou loss
prior_bbox_mask
=
fg_mask_pre_prior
.
unsqueeze
(
-
1
).
repeat
([
1
,
1
,
4
])
pred_bboxes_pos
=
torch
.
masked_select
(
flatten_pred_bboxes
,
prior_bbox_mask
).
reshape
([
-
1
,
4
])
assigned_bboxes_pos
=
torch
.
masked_select
(
assigned_bboxes
,
prior_bbox_mask
).
reshape
([
-
1
,
4
])
bbox_weight
=
torch
.
masked_select
(
assigned_scores
.
sum
(
-
1
),
fg_mask_pre_prior
).
unsqueeze
(
-
1
)
loss_bbox
=
self
.
loss_bbox
(
pred_bboxes_pos
,
assigned_bboxes_pos
,
weight
=
bbox_weight
)
/
assigned_scores_sum
# dfl loss
pred_dist_pos
=
flatten_dist_preds
[
fg_mask_pre_prior
]
assigned_ltrb
=
self
.
bbox_coder
.
encode
(
self
.
flatten_priors_train
[...,
:
2
]
/
self
.
stride_tensor
,
assigned_bboxes
,
max_dis
=
self
.
head_module
.
reg_max
-
1
,
eps
=
0.01
)
assigned_ltrb_pos
=
torch
.
masked_select
(
assigned_ltrb
,
prior_bbox_mask
).
reshape
([
-
1
,
4
])
loss_dfl
=
self
.
loss_dfl
(
pred_dist_pos
.
reshape
(
-
1
,
self
.
head_module
.
reg_max
),
assigned_ltrb_pos
.
reshape
(
-
1
),
weight
=
bbox_weight
.
expand
(
-
1
,
4
).
reshape
(
-
1
),
avg_factor
=
assigned_scores_sum
)
else
:
loss_bbox
=
flatten_pred_bboxes
.
sum
()
*
0
loss_dfl
=
flatten_pred_bboxes
.
sum
()
*
0
if
self
.
world_size
==
-
1
:
_
,
world_size
=
get_dist_info
()
else
:
world_size
=
self
.
world_size
return
dict
(
loss_cls
=
loss_cls
*
num_imgs
*
world_size
,
loss_bbox
=
loss_bbox
*
num_imgs
*
world_size
,
loss_dfl
=
loss_dfl
*
num_imgs
*
world_size
)
def
predict_by_feat
(
self
,
cls_scores
:
List
[
Tensor
],
bbox_preds
:
List
[
Tensor
],
objectnesses
:
Optional
[
List
[
Tensor
]]
=
None
,
batch_img_metas
:
Optional
[
List
[
dict
]]
=
None
,
cfg
:
Optional
[
ConfigDict
]
=
None
,
rescale
:
bool
=
True
,
with_nms
:
bool
=
True
)
->
List
[
InstanceData
]:
"""Transform a batch of output features extracted by the head into
bbox results.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
objectnesses (list[Tensor], Optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, 1, H, W).
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
after the post process. Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
        assert len(cls_scores) == len(bbox_preds)
        if objectnesses is None:
            with_objectnesses = False
        else:
            with_objectnesses = True
            assert len(cls_scores) == len(objectnesses)

        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)

        multi_label = cfg.multi_label
        multi_label &= self.num_classes > 1
        cfg.multi_label = multi_label

        num_imgs = len(batch_img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        # If the shape does not change, use the previous mlvl_priors
        if featmap_sizes != self.featmap_sizes:
            self.mlvl_priors = self.prior_generator.grid_priors(
                featmap_sizes,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device)
            self.featmap_sizes = featmap_sizes
        flatten_priors = torch.cat(self.mlvl_priors)

        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size.numel() * self.num_base_priors, ), stride)
            for featmap_size, stride in zip(featmap_sizes,
                                            self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # flatten cls_scores, bbox_preds and objectness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_decoded_bboxes = self.bbox_coder.decode(
            flatten_priors[None], flatten_bbox_preds, flatten_stride)

        if with_objectnesses:
            flatten_objectness = [
                objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
                for objectness in objectnesses
            ]
            flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
        else:
            flatten_objectness = [None for _ in range(num_imgs)]
        results_list = []
        for (bboxes, scores, objectness,
             img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
                              flatten_objectness, batch_img_metas):
            ori_shape = img_meta['ori_shape']
            scale_factor = img_meta['scale_factor']
            if 'pad_param' in img_meta:
                pad_param = img_meta['pad_param']
            else:
                pad_param = None

            score_thr = cfg.get('score_thr', -1)
            # yolox_style does not require the following operations
            if objectness is not None and score_thr > 0 and not cfg.get(
                    'yolox_style', False):
                conf_inds = objectness > score_thr
                bboxes = bboxes[conf_inds, :]
                scores = scores[conf_inds, :]
                objectness = objectness[conf_inds]

            if objectness is not None:
                # conf = obj_conf * cls_conf
                scores *= objectness[:, None]

            if scores.shape[0] == 0:
                empty_results = InstanceData()
                empty_results.bboxes = bboxes
                empty_results.scores = scores[:, 0]
                empty_results.labels = scores[:, 0].int()
                results_list.append(empty_results)
                continue

            nms_pre = cfg.get('nms_pre', 100000)
            if cfg.multi_label is False:
                scores, labels = scores.max(1, keepdim=True)
                scores, _, keep_idxs, results = filter_scores_and_topk(
                    scores,
                    score_thr,
                    nms_pre,
                    results=dict(labels=labels[:, 0]))
                labels = results['labels']
            else:
                scores, labels, keep_idxs, _ = filter_scores_and_topk(
                    scores, score_thr, nms_pre)

            results = InstanceData(scores=scores,
                                   labels=labels,
                                   bboxes=bboxes[keep_idxs])

            if rescale:
                if pad_param is not None:
                    results.bboxes -= results.bboxes.new_tensor([
                        pad_param[2], pad_param[0], pad_param[2], pad_param[0]
                    ])
                results.bboxes /= results.bboxes.new_tensor(
                    scale_factor).repeat((1, 2))

            if cfg.get('yolox_style', False):
                # do not need max_per_img
                cfg.max_per_img = len(results)

            results = self._bbox_post_process(results=results,
                                              cfg=cfg,
                                              rescale=False,
                                              with_nms=with_nms,
                                              img_meta=img_meta)
            results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
            results.bboxes[:, 1::2].clamp_(0, ori_shape[0])

            results_list.append(results)
        return results_list
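# --- Editor's aside (illustrative, not part of this file) ---------------------
# A sketch of the kind of `test_cfg` this predict_by_feat expects. The key
# names (multi_label, score_thr, nms_pre, yolox_style, nms, max_per_img) are
# the ones read above; the values are assumptions, not a released config.
test_cfg = dict(
    multi_label=True,      # keep one candidate per (prior, class) pair
    score_thr=0.001,       # threshold used by filter_scores_and_topk
    nms_pre=30000,         # top-k candidates kept before NMS
    nms=dict(type='nms', iou_threshold=0.7),
    max_per_img=300)       # final detections per image after NMS
# ------------------------------------------------------------------------------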
yolo_world/models/dense_heads/yolo_world_seg_head.py
# Copyright (c) Lin Song. All rights reserved.
import math
from typing import List, Optional, Tuple, Union, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.dist import get_dist_info
from mmengine.structures import InstanceData
from mmdet.structures import SampleList
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
                         OptMultiConfig, InstanceList)
from mmdet.models.utils import multi_apply, unpack_gt_instances
from mmyolo.models.dense_heads import YOLOv8HeadModule
from mmyolo.models.utils import gt_instances_preprocess
from mmyolo.registry import MODELS, TASK_UTILS
from mmyolo.models.dense_heads.yolov5_ins_head import (ProtoModule,
                                                       YOLOv5InsHead)

from .yolo_world_head import ContrastiveHead, BNContrastiveHead
@MODELS.register_module()
class YOLOWorldSegHeadModule(YOLOv8HeadModule):

    def __init__(self,
                 *args,
                 embed_dims: int,
                 proto_channels: int,
                 mask_channels: int,
                 freeze_bbox: bool = False,
                 freeze_all: bool = False,
                 use_bn_head: bool = False,
                 **kwargs) -> None:
        self.embed_dims = embed_dims
        self.proto_channels = proto_channels
        self.mask_channels = mask_channels
        self.freeze_bbox = freeze_bbox
        self.freeze_all = freeze_all
        self.use_bn_head = use_bn_head
        super().__init__(*args, **kwargs)
    def init_weights(self, prior_prob=0.01):
        """Initialize the weights and biases of the head."""
        super().init_weights()
        for cls_pred, cls_contrast, stride in zip(self.cls_preds,
                                                  self.cls_contrasts,
                                                  self.featmap_strides):
            cls_pred[-1].bias.data[:] = 0.0  # reset bias
            if hasattr(cls_contrast, 'bias'):
                nn.init.constant_(
                    cls_contrast.bias.data,
                    math.log(5 / self.num_classes / (640 / stride)**2))
    def _init_layers(self) -> None:
        """Initialize conv layers in the YOLOv8 head."""
        # Init decoupled head
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.seg_preds = nn.ModuleList()
        self.cls_contrasts = nn.ModuleList()

        reg_out_channels = max(
            (16, self.in_channels[0] // 4, self.reg_max * 4))
        seg_out_channels = max(self.in_channels[0] // 4, self.mask_channels)
        cls_out_channels = max(self.in_channels[0], self.num_classes)

        bbox_norm_cfg = self.norm_cfg
        bbox_norm_cfg['requires_grad'] = not self.freeze_bbox
        if self.freeze_all:
            self.norm_cfg['requires_grad'] = False
            bbox_norm_cfg['requires_grad'] = False

        for i in range(self.num_levels):
            self.reg_preds.append(
                nn.Sequential(
                    ConvModule(in_channels=self.in_channels[i],
                               out_channels=reg_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=bbox_norm_cfg,
                               act_cfg=self.act_cfg),
                    ConvModule(in_channels=reg_out_channels,
                               out_channels=reg_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=bbox_norm_cfg,
                               act_cfg=self.act_cfg),
                    nn.Conv2d(in_channels=reg_out_channels,
                              out_channels=4 * self.reg_max,
                              kernel_size=1)))
            self.cls_preds.append(
                nn.Sequential(
                    ConvModule(in_channels=self.in_channels[i],
                               out_channels=cls_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=bbox_norm_cfg,
                               act_cfg=self.act_cfg),
                    ConvModule(in_channels=cls_out_channels,
                               out_channels=cls_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=bbox_norm_cfg,
                               act_cfg=self.act_cfg),
                    nn.Conv2d(in_channels=cls_out_channels,
                              out_channels=self.embed_dims,
                              kernel_size=1)))
            self.seg_preds.append(
                nn.Sequential(
                    ConvModule(in_channels=self.in_channels[i],
                               out_channels=seg_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=self.norm_cfg,
                               act_cfg=self.act_cfg),
                    ConvModule(in_channels=seg_out_channels,
                               out_channels=seg_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=self.norm_cfg,
                               act_cfg=self.act_cfg),
                    nn.Conv2d(in_channels=seg_out_channels,
                              out_channels=self.mask_channels,
                              kernel_size=1)))
            if self.use_bn_head:
                self.cls_contrasts.append(
                    BNContrastiveHead(self.embed_dims, self.norm_cfg))
            else:
                self.cls_contrasts.append(ContrastiveHead(self.embed_dims))

        proj = torch.arange(self.reg_max, dtype=torch.float)
        self.register_buffer('proj', proj, persistent=False)

        self.proto_pred = ProtoModule(in_channels=self.in_channels[0],
                                      middle_channels=self.proto_channels,
                                      mask_channels=self.mask_channels,
                                      norm_cfg=self.norm_cfg,
                                      act_cfg=self.act_cfg)
        if self.freeze_bbox or self.freeze_all:
            self._freeze_all()
    def _freeze_all(self):
        frozen_list = [self.cls_preds, self.reg_preds, self.cls_contrasts]
        if self.freeze_all:
            frozen_list.extend([self.proto_pred, self.seg_preds])
        for module in frozen_list:
            for m in module.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
                for param in m.parameters():
                    param.requires_grad = False
    def train(self, mode: bool = True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super().train(mode)
        if self.freeze_bbox or self.freeze_all:
            self._freeze_all()
    def forward(self, img_feats: Tuple[Tensor],
                txt_feats: Tensor) -> Tuple[List]:
        """Forward features from the upstream network."""
        assert len(img_feats) == self.num_levels
        txt_feats = [txt_feats for _ in range(self.num_levels)]
        mask_protos = self.proto_pred(img_feats[0])
        cls_logit, bbox_preds, bbox_dist_preds, coeff_preds = multi_apply(
            self.forward_single, img_feats, txt_feats, self.cls_preds,
            self.reg_preds, self.cls_contrasts, self.seg_preds)
        if self.training:
            return (cls_logit, bbox_preds, bbox_dist_preds, coeff_preds,
                    mask_protos)
        else:
            return cls_logit, bbox_preds, None, coeff_preds, mask_protos
    def forward_single(self, img_feat: Tensor, txt_feat: Tensor,
                       cls_pred: nn.ModuleList, reg_pred: nn.ModuleList,
                       cls_contrast: nn.ModuleList,
                       seg_pred: nn.ModuleList) -> Tuple:
        """Forward feature of a single scale level."""
        b, _, h, w = img_feat.shape
        cls_embed = cls_pred(img_feat)
        cls_logit = cls_contrast(cls_embed, txt_feat)
        bbox_dist_preds = reg_pred(img_feat)
        coeff_pred = seg_pred(img_feat)
        if self.reg_max > 1:
            bbox_dist_preds = bbox_dist_preds.reshape(
                [-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)

            # TODO: The get_flops script cannot handle the situation of
            # matmul, and needs to be fixed later
            # bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
            bbox_preds = bbox_dist_preds.softmax(3).matmul(
                self.proj.view([-1, 1])).squeeze(-1)
            bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
        else:
            bbox_preds = bbox_dist_preds
        if self.training:
            return cls_logit, bbox_preds, bbox_dist_preds, coeff_pred
        else:
            return cls_logit, bbox_preds, None, coeff_pred
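# --- Editor's aside (illustrative, not part of this file) ---------------------
# forward_single above decodes boxes in the Distribution Focal Loss style:
# each of the four box sides is a categorical distribution over `reg_max`
# bins, and the decoded distance is the expectation of that distribution.
# A standalone sketch of the same decode (all sizes are assumptions):
import torch

reg_max, b, hw = 16, 2, 8400                       # assumed head settings
proj = torch.arange(reg_max, dtype=torch.float)    # bin indices 0..reg_max-1
dist_logits = torch.randn(b, hw, 4, reg_max)       # (batch, priors, sides, bins)
dist = dist_logits.softmax(dim=3)                  # per-side bin probabilities
distances = dist.matmul(proj.view(-1, 1)).squeeze(-1)  # expectation, (b, hw, 4)
print(distances.shape)  # torch.Size([2, 8400, 4]), in units of the level stride
# ------------------------------------------------------------------------------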
@MODELS.register_module()
class YOLOWorldSegHead(YOLOv5InsHead):

    def __init__(self,
                 head_module: ConfigType,
                 prior_generator: ConfigType = dict(
                     type='mmdet.MlvlPointGenerator',
                     offset=0.5,
                     strides=[8, 16, 32]),
                 bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
                 loss_cls: ConfigType = dict(type='mmdet.CrossEntropyLoss',
                                             use_sigmoid=True,
                                             reduction='none',
                                             loss_weight=0.5),
                 loss_bbox: ConfigType = dict(type='IoULoss',
                                              iou_mode='ciou',
                                              bbox_format='xyxy',
                                              reduction='sum',
                                              loss_weight=7.5,
                                              return_iou=False),
                 loss_dfl=dict(type='mmdet.DistributionFocalLoss',
                               reduction='mean',
                               loss_weight=1.5 / 4),
                 mask_overlap: bool = True,
                 loss_mask: ConfigType = dict(type='mmdet.CrossEntropyLoss',
                                              use_sigmoid=True,
                                              reduction='none'),
                 loss_mask_weight=0.05,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(head_module=head_module,
                         prior_generator=prior_generator,
                         bbox_coder=bbox_coder,
                         loss_cls=loss_cls,
                         loss_bbox=loss_bbox,
                         train_cfg=train_cfg,
                         test_cfg=test_cfg,
                         init_cfg=init_cfg)
        self.loss_dfl = MODELS.build(loss_dfl)
        self.loss_obj = None
        self.mask_overlap = mask_overlap
        self.loss_mask: nn.Module = MODELS.build(loss_mask)
        self.loss_mask_weight = loss_mask_weight
    def special_init(self):
        """Since YOLO series algorithms inherit from YOLOv5Head but each
        algorithm has its own special initialization process, the
        special_init function is designed to deal with this situation.
        """
        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg.assigner)

        # Add common attributes to reduce calculation
        self.featmap_sizes_train = None
        self.num_level_priors = None
        self.flatten_priors_train = None
        self.stride_tensor = None
    def loss(self, img_feats: Tuple[Tensor], txt_feats: Tensor,
             batch_data_samples: Union[list, dict]) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network."""
        outs = self(img_feats, txt_feats)
        # Fast version
        loss_inputs = outs + (batch_data_samples['bboxes_labels'],
                              batch_data_samples['masks'],
                              batch_data_samples['img_metas'])
        losses = self.loss_by_feat(*loss_inputs)
        return losses
    def loss_and_predict(
        self,
        img_feats: Tuple[Tensor],
        txt_feats: Tensor,
        batch_data_samples: SampleList,
        proposal_cfg: Optional[ConfigDict] = None
    ) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples.
        """
        outputs = unpack_gt_instances(batch_data_samples)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = outputs

        outs = self(img_feats, txt_feats)

        loss_inputs = outs + (batch_gt_instances, batch_img_metas,
                              batch_gt_instances_ignore)
        losses = self.loss_by_feat(*loss_inputs)

        predictions = self.predict_by_feat(*outs,
                                           batch_img_metas=batch_img_metas,
                                           cfg=proposal_cfg)
        return losses, predictions
    def forward(self, img_feats: Tuple[Tensor],
                txt_feats: Tensor) -> Tuple[List]:
        """Forward features from the upstream network."""
        return self.head_module(img_feats, txt_feats)
    def predict(self,
                img_feats: Tuple[Tensor],
                txt_feats: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        outs = self(img_feats, txt_feats)
        predictions = self.predict_by_feat(*outs,
                                           batch_img_metas=batch_img_metas,
                                           rescale=rescale)
        return predictions
    def aug_test(self,
                 aug_batch_feats,
                 aug_batch_img_metas,
                 rescale=False,
                 with_ori_nms=False,
                 **kwargs):
        """Test function with test time augmentation."""
        raise NotImplementedError('aug_test is not implemented yet.')
    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            bbox_dist_preds: Sequence[Tensor],
            coeff_preds: Sequence[Tensor],
            proto_preds: Tensor,
            batch_gt_instances: Sequence[InstanceData],
            batch_gt_masks: Sequence[Tensor],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each
                scale level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
                each scale level with shape (bs, H*W, 4, reg_max).
            coeff_preds (Sequence[Tensor]): Mask coefficient predictions for
                each scale level, the channel number is
                num_priors * mask_channels.
            proto_preds (Tensor): Mask prototypes predicted from the
                highest-resolution feature map.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_gt_masks (Sequence[Tensor]): Ground-truth masks for the
                whole batch.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """
        num_imgs = len(batch_img_metas)

        current_featmap_sizes = [
            cls_score.shape[2:] for cls_score in cls_scores
        ]
        # If the shape does not equal, generate new one
        if current_featmap_sizes != self.featmap_sizes_train:
            self.featmap_sizes_train = current_featmap_sizes
            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                self.featmap_sizes_train,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device,
                with_stride=True)
            self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
            self.flatten_priors_train = torch.cat(mlvl_priors_with_stride,
                                                  dim=0)
            self.stride_tensor = self.flatten_priors_train[..., [2]]

        # gt info
        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xyxy
        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()

        # pred info
        flatten_cls_preds = [
            cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.num_classes)
            for cls_pred in cls_scores
        ]
        flatten_pred_bboxes = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        # (bs, n, 4 * reg_max)
        flatten_pred_dists = [
            bbox_pred_org.reshape(num_imgs, -1, self.head_module.reg_max * 4)
            for bbox_pred_org in bbox_dist_preds
        ]
        flatten_pred_coeffs = [
            coeff_pred.permute(0, 2, 3, 1).reshape(
                num_imgs, -1, self.head_module.mask_channels)
            for coeff_pred in coeff_preds
        ]

        flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
        flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
        flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
        flatten_pred_bboxes = self.bbox_coder.decode(
            self.flatten_priors_train[..., :2], flatten_pred_bboxes,
            self.stride_tensor[..., 0])
        flatten_pred_coeffs = torch.cat(flatten_pred_coeffs, dim=1)

        assigned_result = self.assigner(
            (flatten_pred_bboxes.detach()).type(gt_bboxes.dtype),
            flatten_cls_preds.detach().sigmoid(), self.flatten_priors_train,
            gt_labels, gt_bboxes, pad_bbox_flag)

        assigned_bboxes = assigned_result['assigned_bboxes']
        assigned_scores = assigned_result['assigned_scores']
        fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
        assigned_gt_idxs = assigned_result['assigned_gt_idxs']

        assigned_scores_sum = assigned_scores.sum().clamp(min=1)

        loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores).sum()
        loss_cls /= assigned_scores_sum

        # rescale bbox
        assigned_bboxes /= self.stride_tensor
        flatten_pred_bboxes /= self.stride_tensor

        # select positive samples mask
        num_pos = fg_mask_pre_prior.sum()
        if num_pos > 0:
            # when num_pos > 0, assigned_scores_sum will be > 0, so the
            # loss_bbox will not report an error
            # iou loss
            prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
            pred_bboxes_pos = torch.masked_select(
                flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
            assigned_bboxes_pos = torch.masked_select(
                assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
            bbox_weight = torch.masked_select(
                assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
            loss_bbox = self.loss_bbox(
                pred_bboxes_pos, assigned_bboxes_pos,
                weight=bbox_weight) / assigned_scores_sum

            # dfl loss
            pred_dist_pos = flatten_dist_preds[fg_mask_pre_prior]
            assigned_ltrb = self.bbox_coder.encode(
                self.flatten_priors_train[..., :2] / self.stride_tensor,
                assigned_bboxes,
                max_dis=self.head_module.reg_max - 1,
                eps=0.01)
            assigned_ltrb_pos = torch.masked_select(
                assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
            loss_dfl = self.loss_dfl(
                pred_dist_pos.reshape(-1, self.head_module.reg_max),
                assigned_ltrb_pos.reshape(-1),
                weight=bbox_weight.expand(-1, 4).reshape(-1),
                avg_factor=assigned_scores_sum)

            _, c, mask_h, mask_w = proto_preds.shape
            if batch_gt_masks.shape[-2:] != (mask_h, mask_w):
                batch_gt_masks = F.interpolate(batch_gt_masks[None],
                                               (mask_h, mask_w),
                                               mode='nearest')[0]

            loss_mask = torch.zeros(1, device=loss_dfl.device)
            box_sum_flag = pad_bbox_flag.long().sum(dim=1).squeeze(1)
            batch_inds = torch.zeros(num_imgs,
                                     dtype=torch.int64,
                                     device=assigned_gt_idxs.device)[:, None]
            batch_inds[1:] = box_sum_flag.cumsum(dim=0)[:-1][..., None]
            _assigned_gt_idxs = assigned_gt_idxs + batch_inds

            for bs in range(num_imgs):
                bbox_match_inds = assigned_gt_idxs[bs]
                mask_match_inds = _assigned_gt_idxs[bs]
                bbox_match_inds = torch.masked_select(bbox_match_inds,
                                                      fg_mask_pre_prior[bs])
                mask_match_inds = torch.masked_select(mask_match_inds,
                                                      fg_mask_pre_prior[bs])

                # mask
                mask_dim = coeff_preds[0].shape[1]
                prior_mask_mask = fg_mask_pre_prior[bs].unsqueeze(-1).repeat(
                    [1, mask_dim])
                pred_coeffs_pos = torch.masked_select(
                    flatten_pred_coeffs[bs],
                    prior_mask_mask).reshape([-1, mask_dim])

                match_boxes = gt_bboxes[bs][bbox_match_inds] / 4
                normed_boxes = gt_bboxes[bs][bbox_match_inds] / 640
                bbox_area = (normed_boxes[:, 2:] -
                             normed_boxes[:, :2]).prod(dim=1)
                if not mask_match_inds.any():
                    continue
                assert not self.mask_overlap
                mask_gti = batch_gt_masks[mask_match_inds]

                mask_preds = (pred_coeffs_pos @ proto_preds[bs].view(
                    c, -1)).view(-1, mask_h, mask_w)
                loss_mask_full = self.loss_mask(mask_preds, mask_gti)
                _loss_mask = (self.crop_mask(loss_mask_full[None],
                                             match_boxes).mean(dim=(2, 3)) /
                              bbox_area)
                loss_mask += _loss_mask.mean()
        else:
            loss_bbox = flatten_pred_bboxes.sum() * 0
            loss_dfl = flatten_pred_bboxes.sum() * 0
            loss_mask = flatten_pred_coeffs.sum() * 0

        _, world_size = get_dist_info()
        return dict(loss_cls=loss_cls * num_imgs * world_size,
                    loss_bbox=loss_bbox * num_imgs * world_size,
                    loss_dfl=loss_dfl * num_imgs * world_size,
                    loss_mask=loss_mask * self.loss_mask_weight * world_size)
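# --- Editor's aside (illustrative, not part of this file) ---------------------
# A sketch of how this head might be declared in an mmengine-style config.
# The keys mirror the __init__ signatures above; the channel numbers and the
# `use_bn_head` choice are assumptions, not a released YOLO-World-Seg config.
bbox_head = dict(
    type='YOLOWorldSegHead',
    head_module=dict(
        type='YOLOWorldSegHeadModule',
        embed_dims=512,        # dimension of the text/prompt embeddings
        proto_channels=256,    # hidden width of the ProtoModule
        mask_channels=32,      # mask coefficients per prior
        use_bn_head=True),     # BNContrastiveHead instead of ContrastiveHead
    mask_overlap=False,        # the per-image mask loss above asserts this
    loss_mask_weight=0.05)
# ------------------------------------------------------------------------------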
yolo_world/models/detectors/__init__.py
# Copyright (c) Tencent Inc. All rights reserved.
from .yolo_world import YOLOWorldDetector, SimpleYOLOWorldDetector

__all__ = ['YOLOWorldDetector', 'SimpleYOLOWorldDetector']
yolo_world/models/detectors/yolo_world.py
# Copyright (c) Tencent Inc. All rights reserved.
from typing import List, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from mmdet.structures import OptSampleList, SampleList
from mmyolo.models.detectors import YOLODetector
from mmyolo.registry import MODELS
@MODELS.register_module()
class YOLOWorldDetector(YOLODetector):
    """Implementation of the YOLO World series."""

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes=80,
                 num_test_classes=80,
                 **kwargs) -> None:
        self.mm_neck = mm_neck
        self.num_train_classes = num_train_classes
        self.num_test_classes = num_test_classes
        super().__init__(*args, **kwargs)
    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        self.bbox_head.num_classes = self.num_train_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        losses = self.bbox_head.loss(img_feats, txt_feats,
                                     batch_data_samples)
        return losses
    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)

        # self.bbox_head.num_classes = self.num_test_classes
        self.bbox_head.num_classes = txt_feats[0].shape[0]
        results_list = self.bbox_head.predict(img_feats,
                                              txt_feats,
                                              batch_data_samples,
                                              rescale=rescale)

        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples
    def reparameterize(self, texts: List[List[str]]) -> None:
        # encode text embeddings into the detector
        self.texts = texts
        self.text_feats = self.backbone.forward_text(texts)
    def _forward(self,
                 batch_inputs: Tensor,
                 batch_data_samples: OptSampleList = None
                 ) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        results = self.bbox_head.forward(img_feats, txt_feats)
        return results
    def extract_feat(
            self, batch_inputs: Tensor,
            batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract features."""
        txt_feats = None
        if batch_data_samples is None:
            texts = self.texts
            txt_feats = self.text_feats
        elif isinstance(batch_data_samples,
                        dict) and 'texts' in batch_data_samples:
            texts = batch_data_samples['texts']
        elif isinstance(batch_data_samples, list) and hasattr(
                batch_data_samples[0], 'texts'):
            texts = [data_sample.texts for data_sample in batch_data_samples]
        elif hasattr(self, 'text_feats'):
            texts = self.texts
            txt_feats = self.text_feats
        else:
            raise TypeError('batch_data_samples should be dict or list.')

        if txt_feats is not None:
            # forward image only
            img_feats = self.backbone.forward_image(batch_inputs)
        else:
            img_feats, txt_feats = self.backbone(batch_inputs, texts)

        if self.with_neck:
            if self.mm_neck:
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats
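# --- Editor's aside (illustrative, not part of this file) ---------------------
# Assumed offline-vocabulary workflow, mirroring extract_feat's fallback
# branches above: after reparameterize(), the detector caches self.text_feats,
# so later calls that carry no `texts` field (or pass batch_data_samples=None)
# reuse the cached embeddings instead of re-running the text encoder.
# `detector` below is a hypothetical built YOLOWorldDetector instance.
texts = [['person'], ['bicycle'], ['traffic light']]  # one prompt list per class
detector.reparameterize(texts)   # encodes once and caches self.text_feats
# subsequent detector.predict(...) calls can now run image-only
# ------------------------------------------------------------------------------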
@MODELS.register_module()
class SimpleYOLOWorldDetector(YOLODetector):
    """Implementation of the YOLO World series (with cached prompt
    embeddings)."""

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes=80,
                 num_test_classes=80,
                 prompt_dim=512,
                 num_prompts=80,
                 embedding_path='',
                 reparameterized=False,
                 freeze_prompt=False,
                 use_mlp_adapter=False,
                 **kwargs) -> None:
        self.mm_neck = mm_neck
        self.num_training_classes = num_train_classes
        self.num_test_classes = num_test_classes
        self.prompt_dim = prompt_dim
        self.num_prompts = num_prompts
        self.reparameterized = reparameterized
        self.freeze_prompt = freeze_prompt
        self.use_mlp_adapter = use_mlp_adapter
        super().__init__(*args, **kwargs)

        if not self.reparameterized:
            if len(embedding_path) > 0:
                import numpy as np
                self.embeddings = torch.nn.Parameter(
                    torch.from_numpy(np.load(embedding_path)).float())
            else:
                # random init
                embeddings = nn.functional.normalize(torch.randn(
                    (num_prompts, prompt_dim)),
                                                     dim=-1)
                self.embeddings = nn.Parameter(embeddings)

            if self.freeze_prompt:
                self.embeddings.requires_grad = False
            else:
                self.embeddings.requires_grad = True

            if use_mlp_adapter:
                self.adapter = nn.Sequential(
                    nn.Linear(prompt_dim, prompt_dim * 2), nn.ReLU(True),
                    nn.Linear(prompt_dim * 2, prompt_dim))
            else:
                self.adapter = None
    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        self.bbox_head.num_classes = self.num_training_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        if self.reparameterized:
            losses = self.bbox_head.loss(img_feats, batch_data_samples)
        else:
            losses = self.bbox_head.loss(img_feats, txt_feats,
                                         batch_data_samples)
        return losses
    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        self.bbox_head.num_classes = self.num_test_classes
        if self.reparameterized:
            results_list = self.bbox_head.predict(img_feats,
                                                  batch_data_samples,
                                                  rescale=rescale)
        else:
            results_list = self.bbox_head.predict(img_feats,
                                                  txt_feats,
                                                  batch_data_samples,
                                                  rescale=rescale)

        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples
    def _forward(self,
                 batch_inputs: Tensor,
                 batch_data_samples: OptSampleList = None
                 ) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        if self.reparameterized:
            results = self.bbox_head.forward(img_feats)
        else:
            results = self.bbox_head.forward(img_feats, txt_feats)
        return results
    def extract_feat(
            self, batch_inputs: Tensor,
            batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract features."""
        # only image features
        img_feats, _ = self.backbone(batch_inputs, None)

        if not self.reparameterized:
            # use embeddings
            txt_feats = self.embeddings[None]
            if self.adapter is not None:
                txt_feats = self.adapter(txt_feats) + txt_feats
                txt_feats = nn.functional.normalize(txt_feats, dim=-1, p=2)
            txt_feats = txt_feats.repeat(img_feats[0].shape[0], 1, 1)
        else:
            txt_feats = None

        if self.with_neck:
            if self.mm_neck:
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats
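# --- Editor's aside (illustrative, not part of this file) ---------------------
# A runnable sketch (mine) of the prompt-embedding path in extract_feat above:
# learned embeddings, an optional MLP adapter with a residual connection plus
# re-normalization, then broadcasting across the image batch. Sizes assumed.
import torch
import torch.nn as nn

num_prompts, prompt_dim, batch = 80, 512, 4
embeddings = nn.functional.normalize(
    torch.randn(num_prompts, prompt_dim), dim=-1)
adapter = nn.Sequential(nn.Linear(prompt_dim, prompt_dim * 2),
                        nn.ReLU(True),
                        nn.Linear(prompt_dim * 2, prompt_dim))

txt_feats = embeddings[None]                   # (1, num_prompts, prompt_dim)
txt_feats = adapter(txt_feats) + txt_feats     # residual adapter
txt_feats = nn.functional.normalize(txt_feats, dim=-1, p=2)
txt_feats = txt_feats.repeat(batch, 1, 1)      # one copy per image in the batch
print(txt_feats.shape)                         # torch.Size([4, 80, 512])
# ------------------------------------------------------------------------------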
yolo_world/models/layers/__init__.py
# Copyright (c) Tencent Inc. All rights reserved.
# Basic brick modules for PAFPN based on CSPLayers
from .yolo_bricks import (
    CSPLayerWithTwoConv,
    MaxSigmoidAttnBlock,
    MaxSigmoidCSPLayerWithTwoConv,
    ImagePoolingAttentionModule,
    RepConvMaxSigmoidCSPLayerWithTwoConv,
    RepMaxSigmoidCSPLayerWithTwoConv)

__all__ = [
    'CSPLayerWithTwoConv', 'MaxSigmoidAttnBlock',
    'MaxSigmoidCSPLayerWithTwoConv', 'RepConvMaxSigmoidCSPLayerWithTwoConv',
    'RepMaxSigmoidCSPLayerWithTwoConv', 'ImagePoolingAttentionModule'
]