Project: OpenDAS / dcnv3
Commit: f3b13cad (parent 0797920d)
Authored: May 17, 2023 by yeshenglong1
Message: UpDate README.md
Showing 20 changed files with 3917 additions and 3917 deletions.
autonomous_driving/Online-HD-Map-Construction/src/models/heads/dg_head.py  +305 -305
autonomous_driving/Online-HD-Map-Construction/src/models/heads/map_element_detector.py  +627 -627
autonomous_driving/Online-HD-Map-Construction/src/models/heads/polyline_generator.py  +525 -525
autonomous_driving/Online-HD-Map-Construction/src/models/losses/__init__.py  +2 -2
autonomous_driving/Online-HD-Map-Construction/src/models/losses/detr_loss.py  +169 -169
autonomous_driving/Online-HD-Map-Construction/src/models/mapers/__init__.py  +1 -1
autonomous_driving/Online-HD-Map-Construction/src/models/mapers/base_mapper.py  +148 -148
autonomous_driving/Online-HD-Map-Construction/src/models/mapers/vectormapnet.py  +260 -260
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/__init__.py  +1 -1
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/base_transformer.py  +24 -24
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/deformable_transformer.py  +368 -368
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/fp16_dattn.py  +401 -401
autonomous_driving/Online-HD-Map-Construction/tools/dist_test.sh  +10 -10
autonomous_driving/Online-HD-Map-Construction/tools/dist_train.sh  +9 -9
autonomous_driving/Online-HD-Map-Construction/tools/evaluate_submission.py  +28 -28
autonomous_driving/Online-HD-Map-Construction/tools/mmdet_test.py  +190 -190
autonomous_driving/Online-HD-Map-Construction/tools/mmdet_train.py  +170 -170
autonomous_driving/Online-HD-Map-Construction/tools/test.py  +196 -196
autonomous_driving/Online-HD-Map-Construction/tools/train.py  +261 -261
autonomous_driving/Online-HD-Map-Construction/tools/visualization/renderer.py  +222 -222
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/heads/dg_head.py → autonomous_driving/Online-HD-Map-Construction/src/models/heads/dg_head.py
import copy

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import Linear, bias_init_with_prob, build_activation_layer
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import force_fp32
from mmdet.models import HEADS, build_head, build_loss
from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid

from .base_map_head import BaseMapHead
from ..augmentation.sythesis_det import NoiseSythesis


@HEADS.register_module(force=True)
class DGHead(BaseMapHead):

    def __init__(self,
                 det_net_cfg=dict(),
                 gen_net_cfg=dict(),
                 loss_vert=dict(),
                 loss_face=dict(),
                 max_num_vertices=90,
                 top_p_gen_model=0.9,
                 sync_cls_avg_factor=True,
                 augmentation=False,
                 augmentation_kwargs=None,
                 joint_training=False,
                 **kwargs):
        super().__init__()

        # Heads
        self.det_net = build_head(det_net_cfg)
        self.gen_net = build_head(gen_net_cfg)
        self.coord_dim = self.gen_net.coord_dim

        # Loss params
        self.bg_cls_weight = 1.0
        self.sync_cls_avg_factor = sync_cls_avg_factor

        self.max_num_vertices = max_num_vertices
        self.top_p_gen_model = top_p_gen_model
        self.fp16_enabled = False

        self.augmentation = None
        if augmentation:
            augmentation_kwargs.update({'canvas_size': gen_net_cfg.canvas_size})
            self.augmentation = NoiseSythesis(**augmentation_kwargs)

        self.joint_training = joint_training

    def forward(self, batch, img_metas=None, **kwargs):
        '''
        Args:
        Returns:
            outs (Dict):
        '''
        if self.training:
            return self.forward_train(batch, **kwargs)
        else:
            return self.inference(batch, **kwargs)

    def forward_train(self, batch: dict, context: dict, only_det=False, **kwargs):
        '''We use a teacher-forcing strategy.'''
        bbox_dict = self.det_net(context=context)
        outs = dict(bbox=bbox_dict)

        losses_dict, det_match_idxs, det_match_gt_idxs = \
            self.loss_det(batch, outs)

        if only_det:
            return outs, losses_dict

        if self.augmentation is not None:
            polylines, bbox_flat = \
                self.augmentation(batch['gen'], simple_aug=True)
            if bbox_flat is None:
                bbox_flat = batch['gen']['bbox_flat']
            gen_input = dict(
                lines_bs_idx=batch['gen']['lines_bs_idx'],
                lines_cls=batch['gen']['lines_cls'],
                bbox_flat=bbox_flat,
                polylines=polylines,
                polyline_masks=batch['gen']['polyline_masks'])
        else:
            gen_input = batch['gen']

        if self.joint_training:
            # for the downstream polyline generator
            if 'lines' in bbox_dict[-1]:
                # for fixed anchors
                pred_bbox = bbox_dict[-1]['lines'].detach()
            elif 'bboxs' in bbox_dict[-1]:
                # for rpv
                pred_bbox = bbox_dict[-1]['bboxs'].detach()
            else:
                raise NotImplementedError

            # change back to the original gt order.
            det_match_idx = det_match_idxs[-1]
            det_match_gt_idx = det_match_gt_idxs[-1]

            _bboxs = []
            for i, (match_idx, bbox) in enumerate(zip(det_match_idx, pred_bbox)):
                _bboxs.append(bbox[match_idx])
                _bboxs[-1] = _bboxs[-1][torch.argsort(det_match_gt_idx[i])]
            _bboxs = torch.cat(_bboxs, dim=0)

            # quantize the data
            _bboxs = torch.round(_bboxs).type(torch.int32)
            # gen_input['bbox_flat'] = _bboxs
            remain_idx = torch.randperm(_bboxs.shape[0])[:int(_bboxs.shape[0] * 0.2)]

            # for data efficiency
            for k in gen_input.keys():
                if k == 'bbox_flat':
                    gen_input[k] = torch.cat((_bboxs, gen_input[k][remain_idx]), dim=0)
                else:
                    gen_input[k] = torch.cat((gen_input[k], gen_input[k][remain_idx]), dim=0)

        if isinstance(context['bev_embeddings'], tuple):
            context['bev_embeddings'] = context['bev_embeddings'][0]

        poly_dict = self.gen_net(gen_input, context=context)
        outs.update(dict(polylines=poly_dict))

        if self.joint_training:
            for k in batch['gen'].keys():
                batch['gen'][k] = \
                    torch.cat((batch['gen'][k], batch['gen'][k][remain_idx]), dim=0)

        gen_losses_dict = \
            self.loss_gen(batch, outs)
        losses_dict.update(gen_losses_dict)

        return outs, losses_dict

    def loss_det(self, gt: dict, pred: dict):
        loss_dict = {}
        # det
        det_loss_dict, det_match_idx, det_match_gt_idx = \
            self.det_net.loss(gt['det'], pred['bbox'])
        for k, v in det_loss_dict.items():
            loss_dict['det_' + k] = v
        return loss_dict, det_match_idx, det_match_gt_idx

    def loss_gen(self, gt: dict, pred: dict):
        loss_dict = {}
        # gen
        gen_loss_dict = self.gen_net.loss(gt['gen'], pred['polylines'])
        for k, v in gen_loss_dict.items():
            loss_dict['gen_' + k] = v
        return loss_dict

    def loss(self, gt: dict, pred: dict):
        pass

    @torch.no_grad()
    def inference(self, batch: dict = {}, context: dict = {}, gt_condition=False, **kwargs):
        '''
        num_samples_batch: number of samples per batch (batch size)
        '''
        outs = {}

        bbox_dict = self.det_net(context=context)
        bbox_dict = self.det_net.post_process(bbox_dict)
        outs.update(bbox_dict)

        if len(outs['lines_bs_idx']) == 0:
            return None

        if isinstance(context['bev_embeddings'], tuple):
            context['bev_embeddings'] = context['bev_embeddings'][0]

        poly_dict = self.gen_net(
            outs,
            context=context,
            # max_sample_length=self.max_num_vertices,
            max_sample_length=64,
            top_p=self.top_p_gen_model,
            gt_condition=gt_condition)
        outs.update(poly_dict)

        return outs

    def post_process(self, preds: dict, tokens, gts: dict = None, **kwargs):
        '''
        Args:
            XXX
        Outs:
            XXX
        '''
        range_size = self.gen_net.canvas_size.cpu().numpy()
        coord_dim = self.gen_net.coord_dim
        gen_net_name = self.gen_net.name if hasattr(self.gen_net, 'name') else 'gen'

        ret_list = []
        for batch_idx in range(len(tokens)):
            ret_dict_single = {}

            # bbox
            det_gt = None
            if gts is not None:
                det_gt, rec_groundtruth = pack_groundtruth(
                    batch_idx, gts, tokens, range_size, gen_net_name,
                    coord_dim=coord_dim)

            bbox_res = {
                # 'bboxes': preds['bbox'][batch_idx].detach().cpu().numpy(),
                # 'det_gt': det_gt,
                'token': tokens[batch_idx],
                'scores': preds['scores'][batch_idx].detach().cpu().numpy(),
                'labels': preds['labels'][batch_idx].detach().cpu().numpy(),
            }
            ret_dict_single.update(bbox_res)

            # for gen results.
            batch2seq = np.nonzero(
                preds['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
            ret_dict_single.update({'nline': len(batch2seq), 'vectors': []})

            for i in batch2seq:
                pre = preds['polylines'][i].detach().cpu().numpy()
                pre_msk = preds['polyline_masks'][i].detach().cpu().numpy()
                valid_idx = np.nonzero(pre_msk)[0][:-1]

                # From [200,1] to [199,0] to (1,0)
                line = (pre[valid_idx].reshape(-1, coord_dim) - 1) / (range_size - 1)
                ret_dict_single['vectors'].append(line)

            # if gts is not None:
            #     ret_dict_single['groundTruth'] = rec_groundtruth

            ret_list.append(ret_dict_single)

        return ret_list


def pack_groundtruth(batch_idx, gts, tokens, range_size, gen_net_name='gen',
                     coord_dim=2):
    if 'keypoints' in gts['det']:
        gt_bbox = \
            gts['det']['keypoints'][batch_idx].detach().cpu().numpy()
    else:
        gt_bbox = \
            gts['det']['bbox'][batch_idx].detach().cpu().numpy()

    det_gt = {
        'labels': gts['det']['class_label'][batch_idx].detach().cpu().numpy(),
        'bboxes': gt_bbox,
    }

    batch2seq = np.nonzero(
        gts['gen']['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
    ret_groundtruth = {
        'token': tokens[batch_idx],
        'nline': len(batch2seq),
        'labels': gts['gen']['lines_cls'][batch2seq].detach().cpu().numpy(),
        'lines': [],
    }
    for i in batch2seq:
        gt_line = \
            gts['gen']['polylines'].detach().cpu().numpy()[i]
        gt_msk = gts['gen']['polyline_masks'].detach().cpu().numpy()[i]

        if gen_net_name == 'gen_gmm':
            valid_idx = np.nonzero(gt_msk)[0]
        else:
            valid_idx = np.nonzero(gt_msk)[0][:-1]

        # From [200,1] to [199,0] to (1,0)
        line = (gt_line[valid_idx].reshape(-1, coord_dim) - 1) / (range_size - 1)
        ret_groundtruth['lines'].append(line)

    return det_gt, ret_groundtruth
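
Both DGHead.post_process and pack_groundtruth recover normalized coordinates from quantized polyline tokens with the same expression, (v - 1) / (canvas_size - 1), matching the inline comment "From [200,1] to [199,0] to (1,0)". Below is a minimal sketch of that de-quantization; the canvas size and token values are hypothetical, only the formula comes from the file above.

import numpy as np

# Hypothetical canvas: tokens are integers in [1, 200] x [1, 100].
range_size = np.array([200, 100])

# Two quantized vertices of one polyline (values chosen for illustration).
tokens = np.array([[1, 1],
                   [200, 100]])

# Same mapping as in dg_head.py: token 1 -> 0.0, token canvas_size -> 1.0.
line = (tokens - 1) / (range_size - 1)
print(line)  # [[0. 0.], [1. 1.]]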
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/heads/map_element_detector.py → autonomous_driving/Online-HD-Map-Construction/src/models/heads/map_element_detector.py
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, Linear, build_activation_layer, bias_init_with_prob
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import force_fp32
from torch.distributions.categorical import Categorical

from mmdet.core import (multi_apply, build_assigner, build_sampler,
                        reduce_mean)
from mmdet.models import HEADS, build_loss
from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid

from .detr_bbox import DETRBboxHead


@HEADS.register_module(force=True)
class MapElementDetector(nn.Module):

    def __init__(self,
                 canvas_size=(400, 200),
                 discrete_output=False,
                 separate_detect=False,
                 mode='xyxy',
                 bbox_size=None,
                 coord_dim=2,
                 kp_coord_dim=2,
                 num_classes=3,
                 in_channels=128,
                 num_query=100,
                 max_lines=50,
                 score_thre=0.2,
                 num_reg_fcs=2,
                 num_points=100,
                 iterative=False,
                 patch_size=None,
                 sync_cls_avg_factor=True,
                 transformer: dict = None,
                 positional_encoding: dict = None,
                 loss_cls: dict = None,
                 loss_reg: dict = None,
                 train_cfg: dict = None,):
        super().__init__()

        assigner = train_cfg['assigner']
        self.assigner = build_assigner(assigner)
        # DETR sampling=False, so use PseudoSampler
        sampler_cfg = dict(type='PseudoSampler')
        self.sampler = build_sampler(sampler_cfg, context=self)
        self.train_cfg = train_cfg

        self.max_lines = max_lines
        self.score_thre = score_thre

        self.num_query = num_query
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_points = num_points

        # branch
        # if loss_cls.use_sigmoid:
        if loss_cls['use_sigmoid']:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1

        self.iterative = iterative
        self.num_reg_fcs = num_reg_fcs
        self._build_transformer(transformer, positional_encoding)

        # loss params
        self.loss_cls = build_loss(loss_cls)
        self.bg_cls_weight = 0.1
        if self.loss_cls.use_sigmoid:
            self.bg_cls_weight = 0.0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        self.reg_loss = build_loss(loss_reg)

        self.separate_detect = separate_detect
        self.discrete_output = discrete_output

        self.bbox_size = 3 if mode == 'sce' else 2
        if bbox_size is not None:
            self.bbox_size = bbox_size
        self.coord_dim = coord_dim
        # for xyz
        self.kp_coord_dim = kp_coord_dim

        self.register_buffer('canvas_size', torch.tensor(canvas_size))

        # add reg, cls head for each decoder layer
        self._init_layers()
        self._init_branch()
        self.init_weights()
        self._init_embedding()

    def _init_layers(self):
        """Initialize some layers."""
        self.input_proj = Conv2d(
            self.in_channels, self.embed_dims, kernel_size=1)
        # query_pos_embed & query_embed
        self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)

    def _init_embedding(self):
        self.label_embed = nn.Embedding(self.num_classes, self.embed_dims)
        self.img_coord_embed = nn.Linear(2, self.embed_dims)

        # query_pos_embed & query_embed
        self.query_embedding = nn.Embedding(self.num_query, self.embed_dims * 2)
        # for bbox parameters xstart, ystart, xend, yend
        self.bbox_embedding = nn.Embedding(self.bbox_size, self.embed_dims * 2)

    def _init_branch(self,):
        """Initialize classification branch and regression branch of head."""
        fc_cls = Linear(self.embed_dims * self.bbox_size, self.cls_out_channels)
        # fc_cls = Linear(self.embed_dims, self.cls_out_channels)

        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.LayerNorm(self.embed_dims))
            reg_branch.append(nn.ReLU())
        if self.discrete_output:
            reg_branch.append(
                nn.Linear(self.embed_dims, max(self.canvas_size), bias=True,))
        else:
            reg_branch.append(
                nn.Linear(self.embed_dims, self.coord_dim, bias=True,))
        reg_branch = nn.Sequential(*reg_branch)
        # add sigmoid or not

        def _get_clones(module, N):
            return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

        num_pred = self.transformer.decoder.num_layers
        if self.iterative:
            fc_cls = _get_clones(fc_cls, num_pred)
            reg_branch = _get_clones(reg_branch, num_pred)
        else:
            reg_branch = nn.ModuleList(
                [reg_branch for _ in range(num_pred)])
            fc_cls = nn.ModuleList(
                [fc_cls for _ in range(num_pred)])

        self.pre_branches = nn.ModuleDict([
            ('cls', fc_cls),
            ('reg', reg_branch),
        ])

    def init_weights(self):
        """Initialize weights of the DeformDETR head."""
        for p in self.input_proj.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        self.transformer.init_weights()

        # init prediction branch
        for k, v in self.pre_branches.items():
            for param in v.parameters():
                if param.dim() > 1:
                    nn.init.xavier_uniform_(param)

        # focal loss init
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            # for last layer
            if isinstance(self.pre_branches['cls'], nn.ModuleList):
                for m in self.pre_branches['cls']:
                    nn.init.constant_(m.bias, bias_init)
            else:
                m = self.pre_branches['cls']
                nn.init.constant_(m.bias, bias_init)

    def _build_transformer(self, transformer, positional_encoding):
        # transformer
        self.act_cfg = transformer.get('act_cfg',
                                       dict(type='ReLU', inplace=True))
        self.activate = build_activation_layer(self.act_cfg)
        self.positional_encoding = build_positional_encoding(positional_encoding)
        self.transformer = build_transformer(transformer)
        self.embed_dims = self.transformer.embed_dims

    def _prepare_context(self, context):
        """Prepare class label and vertex context."""
        global_context_embedding = None
        image_embeddings = context['bev_embeddings']
        image_embeddings = self.input_proj(image_embeddings)  # only changes feature size

        # Pass images through encoder
        device = image_embeddings.device

        # Add 2D coordinate grid embedding
        B, C, H, W = image_embeddings.shape
        Ws = torch.linspace(-1., 1., W)
        Hs = torch.linspace(-1., 1., H)
        image_coords = torch.stack(torch.meshgrid(Hs, Ws), dim=-1).to(device)
        image_coord_embeddings = self.img_coord_embed(image_coords)
        image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)

        # Reshape spatial grid to sequence
        sequential_context_embeddings = image_embeddings.reshape(B, C, H, W)

        return (global_context_embedding, sequential_context_embeddings)

    def forward(self, context, img_metas=None, multi_scale=False):
        '''
        Args:
            bev_feature (List[Tensor]): shape [B, C, H, W]
                feature in bev view
            img_metas
        Outs:
            preds_dict (Dict):
                all_cls_scores (Tensor): Classification score of all
                    decoder layers, has shape
                    [nb_dec, bs, num_query, cls_out_channels].
                all_lines_preds (Tensor):
                    [nb_dec, bs, num_query, num_points, 2].
        '''
        (global_context_embedding, sequential_context_embeddings) = \
            self._prepare_context(context)

        x = sequential_context_embeddings
        B, C, H, W = x.shape

        query_embedding = self.query_embedding.weight[None, :, None].repeat(
            B, 1, self.bbox_size, 1)
        bbox_embed = self.bbox_embedding.weight
        query_embedding = query_embedding + bbox_embed[None, None]
        query_embedding = query_embedding.view(B, -1, C * 2)

        img_masks = x.new_zeros((B, H, W))
        pos_embed = self.positional_encoding(img_masks)

        # outs_dec: [nb_dec, bs, num_query, embed_dim]
        hs, init_reference, inter_references = self.transformer(
            [x, ],
            [img_masks.type(torch.bool)],
            query_embedding,
            [pos_embed],
            reg_branches=self.reg_branches if self.iterative else None,  # noqa:E501
            cls_branches=None,  # noqa:E501
        )
        outs_dec = hs.permute(0, 2, 1, 3)

        outputs = []
        for i, (query_feat) in enumerate(outs_dec):
            if i == 0:
                reference = init_reference
            else:
                reference = inter_references[i - 1]
            outputs.append(self.get_prediction(i, query_feat, reference))

        return outputs

    def get_prediction(self, level, query_feat, reference):
        bs, num_query, h = query_feat.shape
        query_feat = query_feat.view(bs, -1, self.bbox_size, h)

        ocls = self.pre_branches['cls'][level](query_feat.flatten(-2))
        # ocls = ocls.mean(-2)

        reference = inverse_sigmoid(reference)
        reference = reference.view(bs, -1, self.bbox_size, self.coord_dim)

        tmp = self.pre_branches['reg'][level](query_feat)
        tmp[..., :self.kp_coord_dim] = \
            tmp[..., :self.kp_coord_dim] + reference[..., :self.kp_coord_dim]
        lines = tmp.sigmoid()  # bs, num_query, self.bbox_size, 2
        lines = lines * self.canvas_size[:self.coord_dim]
        lines = lines.flatten(-2)

        return dict(
            lines=lines,  # [bs, num_query, bbox_size*2]
            scores=ocls,  # [bs, num_query, num_class]
            embeddings=query_feat,  # [bs, num_query, bbox_size, h]
        )

    @force_fp32(apply_to=('score_pred', 'lines_pred', 'gt_lines'))
    def _get_target_single(self,
                           score_pred,
                           lines_pred,
                           gt_labels,
                           gt_lines,
                           gt_bboxes_ignore=None):
        """
        Compute regression and classification targets for one image.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            cls_score (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_query, cls_out_channels].
            lines_pred (Tensor):
                shape [num_query, num_points, 2].
            gt_lines (Tensor):
                shape [num_gt, num_points, 2].
            gt_labels (torch.LongTensor)
                shape [num_gt, ]
        Returns:
            tuple[Tensor]: a tuple containing the following for one image.
                - labels (LongTensor): Labels of each image.
                    shape [num_query, 1]
                - label_weights (Tensor): Label weights of each image.
                    shape [num_query, 1]
                - lines_target (Tensor): Lines targets of each image.
                    shape [num_query, num_points, 2]
                - lines_weights (Tensor): Lines weights of each image.
                    shape [num_query, num_points, 2]
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
        """
        num_pred_lines = len(lines_pred)
        # assigner and sampler
        assign_result = self.assigner.assign(
            preds=dict(lines=lines_pred, scores=score_pred,),
            gts=dict(lines=gt_lines, labels=gt_labels,),
            gt_bboxes_ignore=gt_bboxes_ignore)
        sampling_result = self.sampler.sample(
            assign_result, lines_pred, gt_lines)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        pos_gt_inds = sampling_result.pos_assigned_gt_inds

        # label targets 0: foreground, 1: background
        if self.separate_detect:
            labels = gt_lines.new_full((num_pred_lines, ), 1, dtype=torch.long)
        else:
            labels = gt_lines.new_full(
                (num_pred_lines, ), self.num_classes, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_lines.new_ones(num_pred_lines)

        # bbox targets: lines_pred's last dimension is the vocabulary,
        # and the ground truth does not have this dimension.
        if self.discrete_output:
            lines_target = torch.zeros_like(lines_pred[..., 0]).long()
            lines_weights = torch.zeros_like(lines_pred[..., 0])
        else:
            lines_target = torch.zeros_like(lines_pred)
            lines_weights = torch.zeros_like(lines_pred)
        lines_target[pos_inds] = sampling_result.pos_gt_bboxes.type(
            lines_target.dtype)
        lines_weights[pos_inds] = 1.0

        n = lines_weights.sum(-1, keepdim=True)
        lines_weights = lines_weights / n.masked_fill(n == 0, 1)

        return (labels, label_weights, lines_target, lines_weights,
                pos_inds, neg_inds, pos_gt_inds)

    # @force_fp32(apply_to=('preds', 'gts'))
    def get_targets(self, preds, gts, gt_bboxes_ignore_list=None):
        """
        Compute regression and classification targets for a batch of images.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            cls_scores_list (list[Tensor]): Box score logits from a single
                decoder layer for each image with shape [num_query,
                cls_out_channels].
            lines_preds_list (list[Tensor]): [num_query, num_points, 2].
            gt_lines_list (list[Tensor]): Ground truth lines for each image
                with shape (num_gts, num_points, 2)
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Default None.
        Returns:
            tuple: a tuple containing the following targets.
                - labels_list (list[Tensor]): Labels for all images.
                - label_weights_list (list[Tensor]): Label weights for all
                    images.
                - lines_targets_list (list[Tensor]): Lines targets for all
                    images.
                - lines_weight_list (list[Tensor]): Lines weights for all
                    images.
                - num_total_pos (int): Number of positive samples in all
                    images.
                - num_total_neg (int): Number of negative samples in all
                    images.
        """
        assert gt_bboxes_ignore_list is None, \
            'Only supports for gt_bboxes_ignore setting to None.'

        # format the inputs
        if self.separate_detect:
            bbox = [b[m] for b, m in zip(gts['bbox'], gts['bbox_mask'])]
            class_label = torch.zeros_like(gts['bbox_mask']).long()
            class_label = [b[m] for b, m in zip(class_label, gts['bbox_mask'])]
        else:
            class_label = gts['class_label']
            bbox = gts['bbox']

        if self.discrete_output:
            lines_pred = preds['lines'].logits
        else:
            lines_pred = preds['lines']

        bbox = [b.float() for b in bbox]

        (labels_list, label_weights_list, lines_targets_list,
         lines_weights_list, pos_inds_list, neg_inds_list,
         pos_gt_inds_list) = multi_apply(
            self._get_target_single, preds['scores'], lines_pred,
            class_label, bbox, gt_bboxes_ignore=gt_bboxes_ignore_list)

        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))

        new_gts = dict(
            labels=labels_list,
            label_weights=label_weights_list,
            bboxs=lines_targets_list,
            bboxs_weights=lines_weights_list,
        )

        return (new_gts, num_total_pos, num_total_neg,
                pos_inds_list, pos_gt_inds_list)

    # @force_fp32(apply_to=('preds', 'gts'))
    def loss_single(self,
                    preds: dict,
                    gts: dict,
                    gt_bboxes_ignore_list=None,
                    reduction='none'):
        """
        Loss function for outputs from a single decoder layer of a single
        feature level.
        Args:
            cls_scores (Tensor): Box score logits from a single decoder layer
                for all images. Shape [bs, num_query, cls_out_channels].
            lines_preds (Tensor):
                shape [bs, num_query, num_points, 2].
            gt_lines_list (list[Tensor]):
                with shape (num_gts, num_points, 2)
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Default None.
        Returns:
            dict[str, Tensor]: A dictionary of loss components for outputs from
                a single decoder layer.
        """
        # Get targets for each sample
        new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list = \
            self.get_targets(preds, gts, gt_bboxes_ignore_list)

        # Batch all the data
        for k, v in new_gts.items():
            new_gts[k] = torch.stack(v, dim=0)

        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                preds['scores'].new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        # Classification loss
        if self.separate_detect:
            loss_cls = self.bce_loss(
                preds['scores'], new_gts['labels'],
                new_gts['label_weights'], cls_avg_factor)
        else:
            # the loss expects the class dim as the second dim,
            # so flatten the prediction.
            cls_scores = preds['scores'].reshape(-1, self.cls_out_channels)
            cls_labels = new_gts['labels'].reshape(-1)
            cls_weights = new_gts['label_weights'].reshape(-1)
            loss_cls = self.loss_cls(
                cls_scores, cls_labels, cls_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # position NLL loss
        if self.discrete_output:
            loss_reg = -(preds['lines'].log_prob(new_gts['bboxs']) *
                         new_gts['bboxs_weights']).sum() / (num_total_pos)
        else:
            loss_reg = self.reg_loss(
                preds['lines'], new_gts['bboxs'], new_gts['bboxs_weights'],
                avg_factor=num_total_pos)

        loss_dict = dict(
            cls=loss_cls,
            reg=loss_reg,
        )

        return loss_dict, pos_inds_list, pos_gt_inds_list

    @force_fp32(apply_to=('gt_lines_list', 'preds_dicts'))
    def loss(self,
             gts: dict,
             preds_dicts: dict,
             gt_bboxes_ignore=None,
             reduction='mean'):
        """
        Loss Function.
        Args:
            gt_lines_list (list[Tensor]): Ground truth lines for each image
                with shape (num_gts, num_points, 2)
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            preds_dicts:
                all_cls_scores (Tensor): Classification score of all
                    decoder layers, has shape
                    [nb_dec, bs, num_query, cls_out_channels].
                all_lines_preds (Tensor):
                    [nb_dec, bs, num_query, num_points, 2].
            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
                which can be ignored for each image. Default None.
        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert gt_bboxes_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            f'for gt_bboxes_ignore setting to None.'

        # There may be multiple decoder layers
        losses, pos_inds_lists, pos_gt_inds_lists = multi_apply(
            self.loss_single, preds_dicts, gts=gts,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            reduction=reduction)

        # Format the losses
        loss_dict = dict()
        # loss from the last decoder layer
        for k, v in losses[-1].items():
            loss_dict[k] = v

        # Loss from other decoder layers
        num_dec_layer = 0
        for loss in losses[:-1]:
            for k, v in loss.items():
                loss_dict[f'd{num_dec_layer}.{k}'] = v
            num_dec_layer += 1

        return loss_dict, pos_inds_lists, pos_gt_inds_lists

    def post_process(self, preds_dicts: list, **kwargs):
        '''
        Args:
            preds_dicts:
                scores (Tensor): Classification score of all
                    decoder layers, has shape
                    [nb_dec, bs, num_query, cls_out_channels].
                lines (Tensor):
                    [nb_dec, bs, num_query, bbox parameters(4)].
        Outs:
            ret_list (List[Dict]) with length as bs
                list of result dict for each sample in the batch
            XXX
        '''
        preds = preds_dicts[-1]
        batched_cls_scores = preds['scores']
        batched_lines_preds = preds['lines']

        batch_size = batched_cls_scores.size(0)
        device = batched_cls_scores.device

        result_dict = {
            'bbox': [],
            'scores': [],
            'labels': [],
            'bbox_flat': [],
            'lines_cls': [],
            'lines_bs_idx': [],
        }

        for i in range(batch_size):
            cls_scores = batched_cls_scores[i]
            det_preds = batched_lines_preds[i]

            max_num = self.max_lines
            if self.loss_cls.use_sigmoid:
                cls_scores = cls_scores.sigmoid()
                scores, valid_idx = cls_scores.view(-1).topk(max_num)
                det_labels = valid_idx % self.num_classes
                valid_idx = valid_idx // self.num_classes
                det_preds = det_preds[valid_idx]
            else:
                scores, det_labels = F.softmax(
                    cls_scores, dim=-1)[..., :-1].max(-1)
                scores, valid_idx = scores.topk(max_num)
                det_preds = det_preds[valid_idx]
                det_labels = det_labels[valid_idx]

            nline = len(valid_idx)
            result_dict['bbox'].append(det_preds)
            result_dict['scores'].append(scores)
            result_dict['labels'].append(det_labels)
            result_dict['lines_bs_idx'].extend([i] * nline)

        # for the downstream polyline generator
        _bboxs = torch.cat(result_dict['bbox'], dim=0)
        # quantize the data
        result_dict['bbox_flat'] = torch.round(_bboxs).type(torch.int32)
        result_dict['lines_cls'] = torch.cat(result_dict['labels'], dim=0).long()
        result_dict['lines_bs_idx'] = torch.tensor(
            result_dict['lines_bs_idx'], device=device).long()

        return result_dict
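
In the sigmoid branch of post_process, the per-query class scores are flattened before a single topk, so one call jointly selects the best (query, class) pairs; the modulo and floor division then recover the class index and the query index. Here is a standalone sketch of that decoding step, where the tensor sizes are illustrative assumptions rather than values from the repo:

import torch

num_query, num_classes, max_num = 5, 3, 4  # illustrative sizes
cls_scores = torch.randn(num_query, num_classes).sigmoid()  # per-query class scores

# One top-k over the flattened (num_query * num_classes) scores picks the
# highest-scoring (query, class) pairs jointly, as post_process does above.
scores, valid_idx = cls_scores.view(-1).topk(max_num)
det_labels = valid_idx % num_classes   # class of each kept detection
query_idx = valid_idx // num_classes   # which query produced it

# The recovered indices address the original score matrix exactly.
assert torch.equal(scores, cls_scores[query_idx, det_labels])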
import
copy
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
Conv2d
,
Linear
from
mmcv.runner
import
force_fp32
from
torch.distributions.categorical
import
Categorical
from
mmdet.core
import
(
multi_apply
,
build_assigner
,
build_sampler
,
reduce_mean
)
from
mmdet.models
import
HEADS
from
.detr_bbox
import
DETRBboxHead
from
mmdet.models.utils.transformer
import
inverse_sigmoid
from
mmdet.models
import
build_loss
from
mmcv.cnn
import
Linear
,
build_activation_layer
,
bias_init_with_prob
from
mmcv.cnn.bricks.transformer
import
build_positional_encoding
from
mmdet.models.utils
import
build_transformer
@
HEADS
.
register_module
(
force
=
True
)
class
MapElementDetector
(
nn
.
Module
):
def
__init__
(
self
,
canvas_size
=
(
400
,
200
),
discrete_output
=
False
,
separate_detect
=
False
,
mode
=
'xyxy'
,
bbox_size
=
None
,
coord_dim
=
2
,
kp_coord_dim
=
2
,
num_classes
=
3
,
in_channels
=
128
,
num_query
=
100
,
max_lines
=
50
,
score_thre
=
0.2
,
num_reg_fcs
=
2
,
num_points
=
100
,
iterative
=
False
,
patch_size
=
None
,
sync_cls_avg_factor
=
True
,
transformer
:
dict
=
None
,
positional_encoding
:
dict
=
None
,
loss_cls
:
dict
=
None
,
loss_reg
:
dict
=
None
,
train_cfg
:
dict
=
None
,):
super
().
__init__
()
assigner
=
train_cfg
[
'assigner'
]
self
.
assigner
=
build_assigner
(
assigner
)
# DETR sampling=False, so use PseudoSampler
sampler_cfg
=
dict
(
type
=
'PseudoSampler'
)
self
.
sampler
=
build_sampler
(
sampler_cfg
,
context
=
self
)
self
.
train_cfg
=
train_cfg
self
.
max_lines
=
max_lines
self
.
score_thre
=
score_thre
self
.
num_query
=
num_query
self
.
in_channels
=
in_channels
self
.
num_classes
=
num_classes
self
.
num_points
=
num_points
# branch
# if loss_cls.use_sigmoid:
if
loss_cls
[
'use_sigmoid'
]:
self
.
cls_out_channels
=
num_classes
else
:
self
.
cls_out_channels
=
num_classes
+
1
self
.
iterative
=
iterative
self
.
num_reg_fcs
=
num_reg_fcs
self
.
_build_transformer
(
transformer
,
positional_encoding
)
# loss params
self
.
loss_cls
=
build_loss
(
loss_cls
)
self
.
bg_cls_weight
=
0.1
if
self
.
loss_cls
.
use_sigmoid
:
self
.
bg_cls_weight
=
0.0
self
.
sync_cls_avg_factor
=
sync_cls_avg_factor
self
.
reg_loss
=
build_loss
(
loss_reg
)
self
.
separate_detect
=
separate_detect
self
.
discrete_output
=
discrete_output
self
.
bbox_size
=
3
if
mode
==
'sce'
else
2
if
bbox_size
is
not
None
:
self
.
bbox_size
=
bbox_size
self
.
coord_dim
=
coord_dim
# for xyz
self
.
kp_coord_dim
=
kp_coord_dim
self
.
register_buffer
(
'canvas_size'
,
torch
.
tensor
(
canvas_size
))
# add reg, cls head for each decoder layer
self
.
_init_layers
()
self
.
_init_branch
()
self
.
init_weights
()
self
.
_init_embedding
()
def
_init_layers
(
self
):
"""Initialize some layer."""
self
.
input_proj
=
Conv2d
(
self
.
in_channels
,
self
.
embed_dims
,
kernel_size
=
1
)
# query_pos_embed & query_embed
self
.
query_embedding
=
nn
.
Embedding
(
self
.
num_query
,
self
.
embed_dims
)
def
_init_embedding
(
self
):
self
.
label_embed
=
nn
.
Embedding
(
self
.
num_classes
,
self
.
embed_dims
)
self
.
img_coord_embed
=
nn
.
Linear
(
2
,
self
.
embed_dims
)
# query_pos_embed & query_embed
self
.
query_embedding
=
nn
.
Embedding
(
self
.
num_query
,
self
.
embed_dims
*
2
)
# for bbox parameter xstart, ystart, xend, yend
self
.
bbox_embedding
=
nn
.
Embedding
(
self
.
bbox_size
,
self
.
embed_dims
*
2
)
def
_init_branch
(
self
,):
"""Initialize classification branch and regression branch of head."""
fc_cls
=
Linear
(
self
.
embed_dims
*
self
.
bbox_size
,
self
.
cls_out_channels
)
# fc_cls = Linear(self.embed_dims, self.cls_out_channels)
reg_branch
=
[]
for
_
in
range
(
self
.
num_reg_fcs
):
reg_branch
.
append
(
Linear
(
self
.
embed_dims
,
self
.
embed_dims
))
reg_branch
.
append
(
nn
.
LayerNorm
(
self
.
embed_dims
))
reg_branch
.
append
(
nn
.
ReLU
())
if
self
.
discrete_output
:
reg_branch
.
append
(
nn
.
Linear
(
self
.
embed_dims
,
max
(
self
.
canvas_size
),
bias
=
True
,))
else
:
reg_branch
.
append
(
nn
.
Linear
(
self
.
embed_dims
,
self
.
coord_dim
,
bias
=
True
,))
reg_branch
=
nn
.
Sequential
(
*
reg_branch
)
# add sigmoid or not
def
_get_clones
(
module
,
N
):
return
nn
.
ModuleList
([
copy
.
deepcopy
(
module
)
for
i
in
range
(
N
)])
num_pred
=
self
.
transformer
.
decoder
.
num_layers
if
self
.
iterative
:
fc_cls
=
_get_clones
(
fc_cls
,
num_pred
)
reg_branch
=
_get_clones
(
reg_branch
,
num_pred
)
else
:
reg_branch
=
nn
.
ModuleList
(
[
reg_branch
for
_
in
range
(
num_pred
)])
fc_cls
=
nn
.
ModuleList
(
[
fc_cls
for
_
in
range
(
num_pred
)])
self
.
pre_branches
=
nn
.
ModuleDict
([
(
'cls'
,
fc_cls
),
(
'reg'
,
reg_branch
),
])
def
init_weights
(
self
):
"""Initialize weights of the DeformDETR head."""
for
p
in
self
.
input_proj
.
parameters
():
if
p
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
p
)
self
.
transformer
.
init_weights
()
# init prediction branch
for
k
,
v
in
self
.
pre_branches
.
items
():
for
param
in
v
.
parameters
():
if
param
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
param
)
# focal loss init
if
self
.
loss_cls
.
use_sigmoid
:
bias_init
=
bias_init_with_prob
(
0.01
)
# for last layer
if
isinstance
(
self
.
pre_branches
[
'cls'
],
nn
.
ModuleList
):
for
m
in
self
.
pre_branches
[
'cls'
]:
nn
.
init
.
constant_
(
m
.
bias
,
bias_init
)
else
:
m
=
self
.
pre_branches
[
'cls'
]
nn
.
init
.
constant_
(
m
.
bias
,
bias_init
)
def
_build_transformer
(
self
,
transformer
,
positional_encoding
):
# transformer
self
.
act_cfg
=
transformer
.
get
(
'act_cfg'
,
dict
(
type
=
'ReLU'
,
inplace
=
True
))
self
.
activate
=
build_activation_layer
(
self
.
act_cfg
)
self
.
positional_encoding
=
build_positional_encoding
(
positional_encoding
)
self
.
transformer
=
build_transformer
(
transformer
)
self
.
embed_dims
=
self
.
transformer
.
embed_dims
def
_prepare_context
(
self
,
context
):
"""Prepare class label and vertex context."""
global_context_embedding
=
None
image_embeddings
=
context
[
'bev_embeddings'
]
image_embeddings
=
self
.
input_proj
(
image_embeddings
)
# only change feature size
# Pass images through encoder
device
=
image_embeddings
.
device
# Add 2D coordinate grid embedding
B
,
C
,
H
,
W
=
image_embeddings
.
shape
Ws
=
torch
.
linspace
(
-
1.
,
1.
,
W
)
Hs
=
torch
.
linspace
(
-
1.
,
1.
,
H
)
image_coords
=
torch
.
stack
(
torch
.
meshgrid
(
Hs
,
Ws
),
dim
=-
1
).
to
(
device
)
image_coord_embeddings
=
self
.
img_coord_embed
(
image_coords
)
image_embeddings
+=
image_coord_embeddings
[
None
].
permute
(
0
,
3
,
1
,
2
)
# Reshape spatial grid to sequence
sequential_context_embeddings
=
image_embeddings
.
reshape
(
B
,
C
,
H
,
W
)
return
(
global_context_embedding
,
sequential_context_embeddings
)
def
forward
(
self
,
context
,
img_metas
=
None
,
multi_scale
=
False
):
'''
Args:
bev_feature (List[Tensor]): shape [B, C, H, W]
feature in bev view
img_metas
Outs:
preds_dict (Dict):
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_lines_preds (Tensor):
[nb_dec, bs, num_query, num_points, 2].
'''
(
global_context_embedding
,
sequential_context_embeddings
)
=
\
self
.
_prepare_context
(
context
)
x
=
sequential_context_embeddings
B
,
C
,
H
,
W
=
x
.
shape
query_embedding
=
self
.
query_embedding
.
weight
[
None
,:,
None
].
repeat
(
B
,
1
,
self
.
bbox_size
,
1
)
bbox_embed
=
self
.
bbox_embedding
.
weight
query_embedding
=
query_embedding
+
bbox_embed
[
None
,
None
]
query_embedding
=
query_embedding
.
view
(
B
,
-
1
,
C
*
2
)
img_masks
=
x
.
new_zeros
((
B
,
H
,
W
))
pos_embed
=
self
.
positional_encoding
(
img_masks
)
# outs_dec: [nb_dec, bs, num_query, embed_dim]
hs
,
init_reference
,
inter_references
=
self
.
transformer
(
[
x
,],
[
img_masks
.
type
(
torch
.
bool
)],
query_embedding
,
[
pos_embed
],
reg_branches
=
self
.
reg_branches
if
self
.
iterative
else
None
,
# noqa:E501
cls_branches
=
None
,
# noqa:E501
)
outs_dec
=
hs
.
permute
(
0
,
2
,
1
,
3
)
outputs
=
[]
for
i
,
(
query_feat
)
in
enumerate
(
outs_dec
):
if
i
==
0
:
reference
=
init_reference
else
:
reference
=
inter_references
[
i
-
1
]
outputs
.
append
(
self
.
get_prediction
(
i
,
query_feat
,
reference
))
return
outputs
def
get_prediction
(
self
,
level
,
query_feat
,
reference
):
bs
,
num_query
,
h
=
query_feat
.
shape
query_feat
=
query_feat
.
view
(
bs
,
-
1
,
self
.
bbox_size
,
h
)
ocls
=
self
.
pre_branches
[
'cls'
][
level
](
query_feat
.
flatten
(
-
2
))
# ocls = ocls.mean(-2)
reference
=
inverse_sigmoid
(
reference
)
reference
=
reference
.
view
(
bs
,
-
1
,
self
.
bbox_size
,
self
.
coord_dim
)
tmp
=
self
.
pre_branches
[
'reg'
][
level
](
query_feat
)
tmp
[...,:
self
.
kp_coord_dim
]
=
tmp
[...,:
self
.
kp_coord_dim
]
+
reference
[...,:
self
.
kp_coord_dim
]
lines
=
tmp
.
sigmoid
()
# bs, num_query, self.bbox_size,2
lines
=
lines
*
self
.
canvas_size
[:
self
.
coord_dim
]
lines
=
lines
.
flatten
(
-
2
)
return
dict
(
lines
=
lines
,
# [bs, num_query, bboxsize*2]
scores
=
ocls
,
# [bs, num_query, num_class]
embeddings
=
query_feat
,
# [bs, num_query, bbox_size, h]
)
@
force_fp32
(
apply_to
=
(
'score_pred'
,
'lines_pred'
,
'gt_lines'
))
def
_get_target_single
(
self
,
score_pred
,
lines_pred
,
gt_labels
,
gt_lines
,
gt_bboxes_ignore
=
None
):
"""
Compute regression and classification targets for one image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_score (Tensor): Box score logits from a single decoder layer
for one image. Shape [num_query, cls_out_channels].
lines_pred (Tensor):
shape [num_query, num_points, 2].
gt_lines (Tensor):
shape [num_gt, num_points, 2].
gt_labels (torch.LongTensor)
shape [num_gt, ]
Returns:
tuple[Tensor]: a tuple containing the following for one image.
- labels (LongTensor): Labels of each image.
shape [num_query, 1]
- label_weights (Tensor]): Label weights of each image.
shape [num_query, 1]
- lines_target (Tensor): Lines targets of each image.
shape [num_query, num_points, 2]
- lines_weights (Tensor): Lines weights of each image.
shape [num_query, num_points, 2]
- pos_inds (Tensor): Sampled positive indices for each image.
- neg_inds (Tensor): Sampled negative indices for each image.
"""
num_pred_lines
=
len
(
lines_pred
)
# assigner and sampler
assign_result
=
self
.
assigner
.
assign
(
preds
=
dict
(
lines
=
lines_pred
,
scores
=
score_pred
,),
gts
=
dict
(
lines
=
gt_lines
,
labels
=
gt_labels
,
),
gt_bboxes_ignore
=
gt_bboxes_ignore
)
sampling_result
=
self
.
sampler
.
sample
(
assign_result
,
lines_pred
,
gt_lines
)
pos_inds
=
sampling_result
.
pos_inds
neg_inds
=
sampling_result
.
neg_inds
pos_gt_inds
=
sampling_result
.
pos_assigned_gt_inds
# label targets 0: foreground, 1: background
if
self
.
separate_detect
:
labels
=
gt_lines
.
new_full
((
num_pred_lines
,
),
1
,
dtype
=
torch
.
long
)
else
:
labels
=
gt_lines
.
new_full
(
(
num_pred_lines
,
),
self
.
num_classes
,
dtype
=
torch
.
long
)
labels
[
pos_inds
]
=
gt_labels
[
sampling_result
.
pos_assigned_gt_inds
]
label_weights
=
gt_lines
.
new_ones
(
num_pred_lines
)
# bbox targets since lines_pred's last dimension is the vocabulary
# and ground truth dose not have this dimension.
if
self
.
discrete_output
:
lines_target
=
torch
.
zeros_like
(
lines_pred
[...,
0
]).
long
()
lines_weights
=
torch
.
zeros_like
(
lines_pred
[...,
0
])
else
:
lines_target
=
torch
.
zeros_like
(
lines_pred
)
lines_weights
=
torch
.
zeros_like
(
lines_pred
)
lines_target
[
pos_inds
]
=
sampling_result
.
pos_gt_bboxes
.
type
(
lines_target
.
dtype
)
lines_weights
[
pos_inds
]
=
1.0
n
=
lines_weights
.
sum
(
-
1
,
keepdim
=
True
)
lines_weights
=
lines_weights
/
n
.
masked_fill
(
n
==
0
,
1
)
return
(
labels
,
label_weights
,
lines_target
,
lines_weights
,
pos_inds
,
neg_inds
,
pos_gt_inds
)
# @force_fp32(apply_to=('preds', 'gts'))
def
get_targets
(
self
,
preds
,
gts
,
gt_bboxes_ignore_list
=
None
):
"""
Compute regression and classification targets for a batch image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_scores_list (list[Tensor]): Box score logits from a single
decoder layer for each image with shape [num_query,
cls_out_channels].
lines_preds_list (list[Tensor]): [num_query, num_points, 2].
gt_lines_list (list[Tensor]): Ground truth lines for each image
with shape (num_gts, num_points, 2)
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
tuple: a tuple containing the following targets.
- labels_list (list[Tensor]): Labels for all images.
- label_weights_list (list[Tensor]): Label weights for all
\
images.
- lines_targets_list (list[Tensor]): Lines targets for all
\
images.
- lines_weight_list (list[Tensor]): Lines weights for all
\
images.
- num_total_pos (int): Number of positive samples in all
\
images.
- num_total_neg (int): Number of negative samples in all
\
images.
"""
assert
gt_bboxes_ignore_list
is
None
,
\
'Only supports for gt_bboxes_ignore setting to None.'
# format the inputs
if
self
.
separate_detect
:
bbox
=
[
b
[
m
]
for
b
,
m
in
zip
(
gts
[
'bbox'
],
gts
[
'bbox_mask'
])]
class_label
=
torch
.
zeros_like
(
gts
[
'bbox_mask'
]).
long
()
class_label
=
[
b
[
m
]
for
b
,
m
in
zip
(
class_label
,
gts
[
'bbox_mask'
])]
else
:
class_label
=
gts
[
'class_label'
]
bbox
=
gts
[
'bbox'
]
if
self
.
discrete_output
:
lines_pred
=
preds
[
'lines'
].
logits
else
:
lines_pred
=
preds
[
'lines'
]
bbox
=
[
b
.
float
()
for
b
in
bbox
]
(
labels_list
,
label_weights_list
,
lines_targets_list
,
lines_weights_list
,
pos_inds_list
,
neg_inds_list
,
pos_gt_inds_list
)
=
multi_apply
(
self
.
_get_target_single
,
preds
[
'scores'
],
lines_pred
,
class_label
,
bbox
,
gt_bboxes_ignore
=
gt_bboxes_ignore_list
)
num_total_pos
=
sum
((
inds
.
numel
()
for
inds
in
pos_inds_list
))
num_total_neg
=
sum
((
inds
.
numel
()
for
inds
in
neg_inds_list
))
new_gts
=
dict
(
labels
=
labels_list
,
label_weights
=
label_weights_list
,
bboxs
=
lines_targets_list
,
bboxs_weights
=
lines_weights_list
,
)
return
new_gts
,
num_total_pos
,
num_total_neg
,
pos_inds_list
,
pos_gt_inds_list
# @force_fp32(apply_to=('preds', 'gts'))
def
loss_single
(
self
,
preds
:
dict
,
gts
:
dict
,
gt_bboxes_ignore_list
=
None
,
reduction
=
'none'
):
"""
Loss function for outputs from a single decoder layer of a single
feature level.
Args:
cls_scores (Tensor): Box score logits from a single decoder layer
for all images. Shape [bs, num_query, cls_out_channels].
lines_preds (Tensor):
shape [bs, num_query, num_points, 2].
gt_lines_list (list[Tensor]):
with shape (num_gts, num_points, 2)
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components for outputs from
a single decoder layer.
"""
# Get target for each sample
new_gts
,
num_total_pos
,
num_total_neg
,
pos_inds_list
,
pos_gt_inds_list
=
\
self
.
get_targets
(
preds
,
gts
,
gt_bboxes_ignore_list
)
# Batched all data
for
k
,
v
in
new_gts
.
items
():
new_gts
[
k
]
=
torch
.
stack
(
v
,
dim
=
0
)
# construct weighted avg_factor to match with the official DETR repo
cls_avg_factor
=
num_total_pos
*
1.0
+
\
num_total_neg
*
self
.
bg_cls_weight
if
self
.
sync_cls_avg_factor
:
cls_avg_factor
=
reduce_mean
(
preds
[
'scores'
].
new_tensor
([
cls_avg_factor
]))
cls_avg_factor
=
max
(
cls_avg_factor
,
1
)
# Classification loss
if
self
.
separate_detect
:
loss_cls
=
self
.
bce_loss
(
preds
[
'scores'
],
new_gts
[
'labels'
],
new_gts
[
'label_weights'
],
cls_avg_factor
)
else
:
# since the inputs needs the second dim is the class dim, we permute the prediction.
cls_scores
=
preds
[
'scores'
].
reshape
(
-
1
,
self
.
cls_out_channels
)
cls_labels
=
new_gts
[
'labels'
].
reshape
(
-
1
)
cls_weights
=
new_gts
[
'label_weights'
].
reshape
(
-
1
)
loss_cls
=
self
.
loss_cls
(
cls_scores
,
cls_labels
,
cls_weights
,
avg_factor
=
cls_avg_factor
)
# Compute the average number of gt boxes accross all gpus, for
# normalization purposes
num_total_pos
=
loss_cls
.
new_tensor
([
num_total_pos
])
num_total_pos
=
torch
.
clamp
(
reduce_mean
(
num_total_pos
),
min
=
1
).
item
()
# position NLL loss
if
self
.
discrete_output
:
loss_reg
=
-
(
preds
[
'lines'
].
log_prob
(
new_gts
[
'bboxs'
])
*
new_gts
[
'bboxs_weights'
]).
sum
()
/
(
num_total_pos
)
else
:
loss_reg
=
self
.
reg_loss
(
preds
[
'lines'
],
new_gts
[
'bboxs'
],
new_gts
[
'bboxs_weights'
],
avg_factor
=
num_total_pos
)
loss_dict
=
dict
(
cls
=
loss_cls
,
reg
=
loss_reg
,
)
return
loss_dict
,
pos_inds_list
,
pos_gt_inds_list
    @force_fp32(apply_to=('gt_lines_list', 'preds_dicts'))
    def loss(self,
             gts: dict,
             preds_dicts: dict,
             gt_bboxes_ignore=None,
             reduction='mean'):
        """
        Loss function.

        Args:
            gts: Ground truth, containing per-image lines with shape
                (num_gts, num_points, 2) and class indices with shape
                (num_gts, ).
            preds_dicts:
                all_cls_scores (Tensor): Classification score of all
                    decoder layers, has shape
                    [nb_dec, bs, num_query, cls_out_channels].
                all_lines_preds (Tensor):
                    [nb_dec, bs, num_query, num_points, 2].
            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
                which can be ignored for each image. Default None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert gt_bboxes_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            f'for gt_bboxes_ignore setting to None.'

        # There may be multiple decoder layers, so apply the loss per layer.
        losses, pos_inds_lists, pos_gt_inds_lists = multi_apply(
            self.loss_single, preds_dicts, gts=gts,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            reduction=reduction)

        # Format the losses
        loss_dict = dict()
        # Loss from the last decoder layer
        for k, v in losses[-1].items():
            loss_dict[k] = v

        # Loss from the other decoder layers
        num_dec_layer = 0
        for loss in losses[:-1]:
            for k, v in loss.items():
                loss_dict[f'd{num_dec_layer}.{k}'] = v
            num_dec_layer += 1

        return loss_dict, pos_inds_lists, pos_gt_inds_lists
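# A sketch of the key naming this loop produces, mirroring DETR's
# auxiliary-loss convention, with dummy numbers for a 3-layer decoder.
losses = [{'cls': 0.9, 'reg': 1.2},   # layer 0
          {'cls': 0.7, 'reg': 1.0},   # layer 1
          {'cls': 0.5, 'reg': 0.8}]   # last layer

loss_dict = dict(losses[-1])
for i, layer_loss in enumerate(losses[:-1]):
    for k, v in layer_loss.items():
        loss_dict[f'd{i}.{k}'] = v
# {'cls': 0.5, 'reg': 0.8, 'd0.cls': 0.9, 'd0.reg': 1.2,
#  'd1.cls': 0.7, 'd1.reg': 1.0}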
    def post_process(self, preds_dicts: list, **kwargs):
        '''
        Args:
            preds_dicts:
                scores (Tensor): Classification score of all
                    decoder layers, has shape
                    [nb_dec, bs, num_query, cls_out_channels].
                lines (Tensor):
                    [nb_dec, bs, num_query, bbox parameters(4)].
        Outs:
            ret_list (List[Dict]) with length as bs:
                list of result dicts for each sample in the batch
            XXX
        '''
        preds = preds_dicts[-1]
        batched_cls_scores = preds['scores']
        batched_lines_preds = preds['lines']

        batch_size = batched_cls_scores.size(0)
        device = batched_cls_scores.device

        result_dict = {
            'bbox': [],
            'scores': [],
            'labels': [],
            'bbox_flat': [],
            'lines_cls': [],
            'lines_bs_idx': [],
        }
        for i in range(batch_size):
            cls_scores = batched_cls_scores[i]
            det_preds = batched_lines_preds[i]

            max_num = self.max_lines
            if self.loss_cls.use_sigmoid:
                cls_scores = cls_scores.sigmoid()
                scores, valid_idx = cls_scores.view(-1).topk(max_num)
                det_labels = valid_idx % self.num_classes
                valid_idx = valid_idx // self.num_classes
                det_preds = det_preds[valid_idx]
            else:
                scores, det_labels = \
                    F.softmax(cls_scores, dim=-1)[..., :-1].max(-1)
                scores, valid_idx = scores.topk(max_num)
                det_preds = det_preds[valid_idx]
                det_labels = det_labels[valid_idx]

            nline = len(valid_idx)
            result_dict['bbox'].append(det_preds)
            result_dict['scores'].append(scores)
            result_dict['labels'].append(det_labels)
            result_dict['lines_bs_idx'].extend([i] * nline)

        # For the downstream polyline generator
        _bboxs = torch.cat(result_dict['bbox'], dim=0)
        # Quantize the data
        result_dict['bbox_flat'] = torch.round(_bboxs).type(torch.int32)
        result_dict['lines_cls'] = \
            torch.cat(result_dict['labels'], dim=0).long()
        result_dict['lines_bs_idx'] = \
            torch.tensor(result_dict['lines_bs_idx'], device=device).long()

        return result_dict
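# A minimal standalone sketch of the flat top-k decoding in the sigmoid
# branch above, with made-up sizes: each flat index encodes both a query
# and a class, recovered with % and //.
import torch

num_query, num_classes, max_num = 4, 3, 2
cls_scores = torch.rand(num_query, num_classes)

scores, flat_idx = cls_scores.view(-1).topk(max_num)
det_labels = flat_idx % num_classes   # class index within a query
query_idx = flat_idx // num_classes   # which query produced the score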
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/heads/polyline_generator.py → autonomous_driving/Online-HD-Map-Construction/src/models/heads/polyline_generator.py
View file @ f3b13cad
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from mmdet.models import HEADS
from mmcv.runner import force_fp32, auto_fp16

from .detgen_utils.causal_trans import (CausalTransformerDecoder,
                                        CausalTransformerDecoderLayer)
from .detgen_utils.utils import (dequantize_verts,
                                 generate_square_subsequent_mask,
                                 quantize_verts, top_k_logits, top_p_logits)


@HEADS.register_module(force=True)
class PolylineGenerator(nn.Module):
    """Autoregressive generative model of n-gon meshes.

    Operates on sets of input vertices as well as flattened face sequences
    with new-face and stopping tokens:
        [f_0^0, f_0^1, f_0^2, NEW, f_1^0, f_1^1, ..., STOP]
    Input vertices are encoded using a Transformer encoder.
    Input face sequences are embedded and tagged with learned position
    indicators, as well as their corresponding vertex embeddings. A
    transformer decoder outputs a pointer which is compared to each vertex
    embedding to obtain a distribution over vertex indices.
    """

    def __init__(self,
                 in_channels,
                 encoder_config,
                 decoder_config,
                 class_conditional=True,
                 num_classes=55,
                 decoder_cross_attention=True,
                 use_discrete_vertex_embeddings=True,
                 condition_points_num=3,
                 coord_dim=2,
                 canvas_size=(400, 200),
                 max_seq_length=500,
                 name='gen_model'):
        """Initializes FaceModel.

        Args:
            encoder_config: Dictionary with TransformerEncoder config.
            decoder_config: Dictionary with TransformerDecoder config.
            class_conditional: If True, condition on learned class embeddings.
            num_classes: Number of classes to condition on.
            decoder_cross_attention: If True, use cross attention from decoder
                queries into encoder outputs.
            use_discrete_vertex_embeddings: If True, use discrete vertex
                embeddings.
            max_seq_length: Maximum face sequence length. Used for learned
                position embeddings.
            name: Name of variable scope.
        """
        super(PolylineGenerator, self).__init__()
        self.embedding_dim = decoder_config['layer_config']['d_model']
        self.class_conditional = class_conditional
        self.num_classes = num_classes
        self.max_seq_length = max_seq_length
        self.decoder_cross_attention = decoder_cross_attention
        self.use_discrete_vertex_embeddings = use_discrete_vertex_embeddings
        self.condition_points_num = condition_points_num
        self.fp16_enabled = False

        self.coord_dim = coord_dim
        # 3 if we use xyz, else 2 when we use xy
        self.kp_coord_dim = coord_dim if coord_dim == 2 else 2  # XXX
        self.register_buffer('canvas_size', torch.tensor(canvas_size))

        # initialize the model
        self._project_to_logits = nn.Linear(
            self.embedding_dim,
            max(canvas_size) + 1,  # + 1 for stopping token. use_bias=True,
        )
        self.input_proj = nn.Conv2d(
            in_channels, self.embedding_dim, kernel_size=1)

        decoder_layer = CausalTransformerDecoderLayer(
            **decoder_config.pop('layer_config'))
        self.decoder = CausalTransformerDecoder(
            decoder_layer, **decoder_config)

        self._init_embedding()
        self.init_weights()
    def _init_embedding(self):
        if self.class_conditional:
            self.label_embed = nn.Embedding(
                self.num_classes, self.embedding_dim)

        self.coord_embed = nn.Embedding(self.coord_dim, self.embedding_dim)
        self.pos_embeddings = nn.Embedding(
            self.max_seq_length, self.embedding_dim)

        # Indicates whether a position is the start of the line or its end.
        self.bbox_context_embed = \
            nn.Embedding(self.condition_points_num, self.embedding_dim)

        self.img_coord_embed = nn.Linear(2, self.embedding_dim)

        # initialize the vertices embedding
        if self.use_discrete_vertex_embeddings:
            self.vertex_embed = nn.Embedding(
                max(self.canvas_size) + 1, self.embedding_dim)
        else:
            self.vertex_embed = nn.Linear(1, self.embedding_dim)

    def init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    def _embed_kps(self, bbox):
        bbox_len = bbox.shape[-1]

        # Bbox context (start/end indicator)
        bbox_embedding = self.bbox_context_embed(
            (torch.arange(bbox_len, device=bbox.device)
             / self.kp_coord_dim).floor().long())

        # Coord indicators (x, y)
        coord_embeddings = self.coord_embed(
            torch.arange(bbox_len, device=bbox.device) % self.kp_coord_dim)

        # Discrete vertex value embeddings
        vert_embeddings = self.vertex_embed(bbox)

        return vert_embeddings + (bbox_embedding + coord_embeddings)[None]
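# A standalone sketch of the index arithmetic above, with a hypothetical
# flattened keypoint sequence [x0, y0, x1, y1, x2, y2]: // picks the point
# index, % picks the coordinate axis.
import torch

kp_coord_dim = 2
idx = torch.arange(6)
point_idx = (idx / kp_coord_dim).floor().long()  # tensor([0, 0, 1, 1, 2, 2])
axis_idx = idx % kp_coord_dim                    # tensor([0, 1, 0, 1, 0, 1])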
    def _prepare_context(self, batch, context):
        """Prepare class label and vertex context."""
        global_context_embedding = None
        if self.class_conditional:
            global_context_embedding = self.label_embed(batch['lines_cls'])

        bbox_embeddings = self._embed_kps(batch['bbox_flat'])
        if global_context_embedding is not None:
            global_context_embedding = torch.cat(
                [global_context_embedding[:, None], bbox_embeddings], dim=1)

        # Pass images through encoder
        image_embeddings = assign_bev(
            context['bev_embeddings'], batch['lines_bs_idx'])
        image_embeddings = self.input_proj(image_embeddings)
        device = image_embeddings.device

        # Add 2D coordinate grid embedding
        H, W = image_embeddings.shape[2:]
        Ws = torch.linspace(-1., 1., W)
        Hs = torch.linspace(-1., 1., H)
        image_coords = torch.stack(
            torch.meshgrid(Hs, Ws), dim=-1).to(device)
        image_coord_embeddings = self.img_coord_embed(image_coords)
        image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)

        # Reshape spatial grid to sequence
        B = image_embeddings.shape[0]
        sequential_context_embeddings = image_embeddings.reshape(
            B, self.embedding_dim, -1).permute(0, 2, 1)

        return (global_context_embedding, sequential_context_embeddings)
    def _embed_inputs(self, seqs, condition_embedding=None):
        """Embeds face sequences and adds within- and between-face positions.

        Args:
            seqs: B, seqlen=vlen*3
            condition_embedding: B, [c,xs,ys,xe,ye](5), h

        Returns:
            embeddings: B, seqlen, h
        """
        B, seq_len = seqs.shape[:2]

        # Position embeddings: seq_len, h
        pos_embeddings = self.pos_embeddings(
            (torch.arange(seq_len, device=seqs.device)
             / self.coord_dim).floor().long())

        # Coord indicators (x, y, z(optional))
        coord_embeddings = self.coord_embed(
            torch.arange(seq_len, device=seqs.device) % self.coord_dim)

        # Discrete vertex value embeddings
        vert_embeddings = self.vertex_embed(seqs)

        # Aggregate embeddings
        embeddings = vert_embeddings + \
            (coord_embeddings + pos_embeddings)[None]
        embeddings = torch.cat([condition_embedding, embeddings], dim=1)

        return embeddings
    def forward(self, batch: dict, **kwargs):
        """Pass batch through face model and get log probabilities.

        Args:
            batch: Dictionary containing:
                'vertices_dequantized': Tensor of shape
                    [batch_size, num_vertices, 3].
                'faces': int32 tensor of shape [batch_size, seq_length] with
                    flattened faces.
                'vertices_mask': float32 tensor with shape
                    [batch_size, num_vertices] that masks padded elements in
                    'vertices'.
        """
        if self.training:
            return self.forward_train(batch, **kwargs)
        else:
            return self.inference(batch, **kwargs)

    def sperate_forward(self, batch, context, **kwargs):
        polyline_length = batch['polyline_masks'].sum(-1)
        c1, c2, revert_idx, size = get_chunk_idx(polyline_length)
        sizes = [size, polyline_length.max()]

        polyline_logits = []
        for c_idx, size in zip([c1, c2], sizes):
            new_batch = assign_batch(batch, c_idx, size)
            _poly_logits = self._forward_train(new_batch, context, **kwargs)
            polyline_logits.append(_poly_logits)

        # maybe improve the speed
        for i, (_poly_logits, size) in enumerate(
                zip(polyline_logits, sizes)):
            if size < sizes[1]:
                _poly_logits = F.pad(
                    _poly_logits, (0, 0, 0, sizes[1] - size), "constant", 0)
                polyline_logits[i] = _poly_logits

        polyline_logits = torch.cat(polyline_logits, 0)
        polyline_logits = polyline_logits[revert_idx]

        cat_dist = Categorical(logits=polyline_logits)
        return {'polylines': cat_dist}
    def forward_train(self, batch: dict, context: dict, **kwargs):
        """
        Returns:
            pred_dist: Categorical predictive distribution with batch shape
                [batch_size, seq_length].
        """
        # We use the gt vertices. The single-pass branch is disabled,
        # so the chunked `sperate_forward` path is always taken.
        if False:
            polyline_logits = self._forward_train(batch, context, **kwargs)
            cat_dist = Categorical(logits=polyline_logits)
            return {'polylines': cat_dist}
        else:
            return self.sperate_forward(batch, context, **kwargs)

    def _forward_train(self, batch: dict, context: dict, **kwargs):
        """
        Returns:
            pred_dist: Categorical predictive distribution with batch shape
                [batch_size, seq_length].
        """
        # we use the gt vertices
        global_context, seq_context = self._prepare_context(batch, context)

        logits = self.body(
            # Last element not used for preds
            batch['polylines'][:, :-1],
            global_context_embedding=global_context,
            sequential_context_embeddings=seq_context,
            return_logits=True,
            is_training=self.training)

        return logits
    @force_fp32(apply_to=('global_context_embedding',
                          'sequential_context_embeddings', 'cache'))
    def body(self,
             seqs,
             global_context_embedding=None,
             sequential_context_embeddings=None,
             temperature=1.,
             top_k=0,
             top_p=1.,
             cache=None,
             return_logits=False,
             is_training=True):
        """Outputs a categorical dist for vertex indices.

        Body of the face model.
        """
        # Embed inputs
        condition_len = global_context_embedding.shape[1]
        decoder_inputs = self._embed_inputs(seqs, global_context_embedding)

        # Pass through the Transformer decoder. Our memory-efficient decoder
        # only supports the sequence-first setting.
        decoder_inputs = decoder_inputs.transpose(0, 1)
        if sequential_context_embeddings is not None:
            sequential_context_embeddings = \
                sequential_context_embeddings.transpose(0, 1)

        causal_msk = None
        if is_training:
            causal_msk = generate_square_subsequent_mask(
                decoder_inputs.shape[0],
                condition_len=condition_len,
                device=decoder_inputs.device)

        decoder_outputs, cache = self.decoder(
            tgt=decoder_inputs,
            cache=cache,
            memory=sequential_context_embeddings,
            causal_mask=causal_msk,
        )
        decoder_outputs = decoder_outputs.transpose(0, 1)
        # We only need the predicted part of the sequence.
        decoder_outputs = decoder_outputs[:, condition_len - 1:]

        # Get logits and optionally process for sampling
        logits = self._project_to_logits(decoder_outputs)

        # y mask
        _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
        vertices_mask_y = (_vert_mask < self.canvas_size[1] + 1)
        vertices_mask_y[0] = False  # y position doesn't have a stop token
        logits[:, 1::self.coord_dim] = logits[:, 1::self.coord_dim] * \
            vertices_mask_y - ~vertices_mask_y * 1e9

        if self.coord_dim > 2:
            # z mask
            _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
            vertices_mask_z = (_vert_mask < self.canvas_size[2] + 1)
            vertices_mask_z[0] = False  # z position doesn't have a stop token
            logits[:, 2::self.coord_dim] = logits[:, 2::self.coord_dim] * \
                vertices_mask_z - ~vertices_mask_z * 1e9

        logits = logits / temperature
        logits = top_k_logits(logits, top_k)
        logits = top_p_logits(logits, top_p)

        if return_logits:
            return logits

        cat_dist = Categorical(logits=logits)
        return cat_dist, cache
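# The imported top_k_logits / top_p_logits helpers live in
# detgen_utils.utils and are not shown in this diff. The sketch below shows
# the standard filtering they are assumed to implement; function names and
# details here are assumptions, not the repository's actual code.
import torch

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep only the k largest logits; mask the rest (k=0 disables)."""
    if k <= 0:
        return logits
    kth_best = logits.topk(k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_best, float('-inf'))

def top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus sampling: keep the smallest token set with cum. prob >= p."""
    if p >= 1.0:
        return logits
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    probs = sorted_logits.softmax(-1)
    # Cumulative probability of strictly earlier tokens; the first token
    # crossing the threshold is still kept.
    cum_before = probs.cumsum(-1) - probs
    remove = cum_before > p
    mask = remove.scatter(-1, sorted_idx, remove)  # back to original order
    return logits.masked_fill(mask, float('-inf'))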
    @force_fp32(apply_to=('pred',))
    def loss(self, gt: dict, pred: dict):
        weight = gt['polyline_weights']
        mask = gt['polyline_masks']

        loss = -torch.sum(
            pred['polylines'].log_prob(gt['polylines'])
            * mask * weight) / weight.sum()

        return {'seq': loss}
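# The sequence loss above is a masked, weighted negative log-likelihood.
# A toy example with hypothetical shapes, showing how the mask removes
# padded positions from the sum.
import torch
from torch.distributions.categorical import Categorical

logits = torch.randn(2, 4, 5)            # 2 sequences, length 4, vocab 5
targets = torch.randint(0, 5, (2, 4))
mask = torch.tensor([[1., 1., 1., 0.],   # last position is padding
                     [1., 1., 1., 1.]])
weight = mask.clone()                    # uniform weight on valid tokens

dist = Categorical(logits=logits)
nll = -(dist.log_prob(targets) * mask * weight).sum() / weight.sum()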
    def inference(self,
                  batch: dict,
                  context: dict,
                  max_sample_length=None,
                  temperature=1.,
                  top_k=0,
                  top_p=1.,
                  only_return_complete=False,
                  gt_condition=False,
                  **kwargs):
        """Sample from the face model using caching.

        Args:
            context: Dictionary of context, including 'vertices' and
                'vertices_mask'. See _prepare_context for details.
            max_sample_length: Maximum length of sampled vertex sequences.
                Sequences that do not complete are truncated.
            temperature: Scalar softmax temperature > 0.
            top_k: Number of tokens to keep for top-k sampling.
            top_p: Proportion of probability mass to keep for top-p sampling.
            only_return_complete: If True, only return completed samples.
                Otherwise return all samples along with a completed indicator.

        Returns:
            outputs: Output dictionary with fields:
                'completed': Boolean tensor of shape [num_samples]. True if
                    the corresponding sample completed within
                    max_sample_length.
                'polylines': Tensor of samples with shape
                    [num_samples, seq_len].
                'polyline_masks': Boolean tensor masking positions beyond
                    each sample's stopping token.
        """
        # Prepare the conditional variables
        global_context, seq_context = self._prepare_context(batch, context)
        device = global_context.device
        batch_size = global_context.shape[0]

        # While-loop sampling with caching
        samples = torch.empty(
            [batch_size, 0], dtype=torch.int32, device=device)
        max_sample_length = max_sample_length or self.max_seq_length
        seq_len = max_sample_length * self.coord_dim + 1

        cache = None
        decoded_tokens = \
            torch.zeros((batch_size, seq_len), device=device,
                        dtype=torch.long)
        remain_idx = torch.arange(batch_size, device=device)
        for i in range(seq_len):
            # While-loop body for the autoregressive calculation.
            pred_dist, cache = self.body(
                samples,
                global_context_embedding=global_context,
                sequential_context_embeddings=seq_context,
                cache=cache,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                is_training=False)
            samples = pred_dist.sample()
            decoded_tokens[remain_idx, i] = samples[:, -1]

            # Stopping condition for the autoregressive calculation.
            if not (decoded_tokens[:, :i + 1] != 0).all(-1).any():
                break

            # Update state; drop sequences whose new position is zero.
            valid_idx = (samples[:, -1] != 0).nonzero(as_tuple=True)[0]
            remain_idx = remain_idx[valid_idx]
            cache = cache[:, :, valid_idx]
            global_context = global_context[valid_idx]
            seq_context = seq_context[valid_idx]
            samples = samples[valid_idx]

        # decoded_tokens = torch.cat(decoded_tokens, dim=1)
        decoded_tokens = decoded_tokens[:, :i + 1]
        outputs = self.post_process(
            decoded_tokens, seq_len, device, only_return_complete)

        return outputs
    def post_process(self, polyline, max_seq_len=None, device=None,
                     only_return_complete=True):
        '''
        Format the predictions and find the mask.
        '''
        # Record completed samples
        complete_samples = (polyline == 0).any(-1)

        # Find the number of faces
        sample_seq_length = polyline.shape[-1]
        _polyline_mask = torch.arange(sample_seq_length)[None].to(device)

        # Use the largest stopping point for incomplete samples.
        valid_polyline_len = torch.full_like(
            polyline[:, 0], sample_seq_length)
        zero_inds = (polyline == 0).type(torch.int32).argmax(-1)
        # Real length
        valid_polyline_len[complete_samples] = zero_inds[complete_samples] + 1
        polyline_mask = _polyline_mask < valid_polyline_len[:, None]

        # Mask faces beyond the stopping token with zeros
        polyline = polyline * polyline_mask

        # Pad to maximum size with zeros
        pad_size = max_seq_len - sample_seq_length
        polyline = F.pad(polyline, (0, pad_size))
        # polyline_mask = F.pad(polyline_mask, (0, pad_size))

        # XXX
        # if only_return_complete:
        #     polyline = polyline[complete_samples]
        #     valid_polyline_len = valid_polyline_len[complete_samples]
        #     context = tf.nest.map_structure(
        #         lambda x: tf.boolean_mask(x, complete_samples), context)
        #     complete_samples = complete_samples[complete_samples]

        outputs = {
            'completed': complete_samples,
            'polylines': polyline,
            'polyline_masks': polyline_mask,
        }
        return outputs
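# A small standalone check of the length computation above: argmax over a
# boolean-to-int cast returns the position of the first stop token (0).
# Sequences with no stop token keep the full length.
import torch

polyline = torch.tensor([[3, 5, 0, 0],
                         [2, 4, 7, 9]])        # second row never stopped
complete = (polyline == 0).any(-1)              # tensor([ True, False])
first_zero = (polyline == 0).int().argmax(-1)   # tensor([2, 0])

valid_len = torch.full_like(polyline[:, 0], polyline.shape[-1])
valid_len[complete] = first_zero[complete] + 1  # tensor([3, 4])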
def find_best_sperate_plan(idx, array):
    # Heuristic cost: padding saved by splitting the sorted lengths after
    # `idx` (length gap to the longest sequence times the chunk size).
    h = array[-1] - array[idx]
    w = idx
    cost = h * w
    return cost


def get_chunk_idx(polyline_length):
    _polyline_length, polyline_length_idx = torch.sort(polyline_length)
    costs = []
    for i in range(len(_polyline_length)):
        cost = find_best_sperate_plan(i, _polyline_length)
        costs.append(cost)

    seperate_point = torch.stack(costs).argmax()
    chunk1 = polyline_length_idx[:seperate_point + 1]
    chunk2 = polyline_length_idx[seperate_point + 1:]
    revert_idx = torch.argsort(polyline_length_idx)

    return chunk1, chunk2, revert_idx, _polyline_length[seperate_point]


def assign_bev(feat, idx):
    return feat[idx]


def assign_batch(batch, idx, size):
    new_batch = {}
    for k, v in batch.items():
        new_batch[k] = v[idx]
        if new_batch[k].ndim > 1:
            new_batch[k] = new_batch[k][:, :size]
    return new_batch
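# A usage sketch of the chunking helpers above (assumes they are in scope;
# the lengths are made up): short sequences are grouped so each chunk is
# padded only to its own maximum.
import torch

lengths = torch.tensor([3, 40, 4, 38, 5])
c1, c2, revert_idx, size1 = get_chunk_idx(lengths)
# c1 holds the indices of the short sequences [0, 2, 4] (padded to size1=5),
# c2 the long ones [3, 1] (padded to lengths.max()); revert_idx restores the
# original batch order after the two chunks are concatenated.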
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/losses/__init__.py → autonomous_driving/Online-HD-Map-Construction/src/models/losses/__init__.py
View file @ f3b13cad
from .detr_loss import LinesLoss, MasksLoss, LenLoss
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/losses/detr_loss.py → autonomous_driving/Online-HD-Map-Construction/src/models/losses/detr_loss.py
View file @ f3b13cad
import torch
from torch import nn as nn
from torch.nn import functional as F

from mmdet.models.losses import l1_loss
from mmdet.models.losses.utils import weighted_loss

import mmcv
from mmdet.models.builder import LOSSES


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def smooth_l1_loss(pred, target, beta=1.0):
    """Smooth L1 loss.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.
        beta (float, optional): The threshold in the piecewise function.
            Defaults to 1.0.

    Returns:
        torch.Tensor: Calculated loss
    """
    assert beta > 0
    if target.numel() == 0:
        return pred.sum() * 0
    assert pred.size() == target.size()
    diff = torch.abs(pred - target)
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                       diff - 0.5 * beta)
    return loss
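# Smooth L1 is quadratic for |pred - target| < beta and linear beyond it.
# A quick numeric check (standalone, without the mmcv decorators).
import torch

pred = torch.tensor([0.2, 2.0])
target = torch.tensor([0.0, 0.0])
beta = 1.0
diff = (pred - target).abs()                 # tensor([0.2, 2.0])
loss = torch.where(diff < beta,
                   0.5 * diff * diff / beta,  # 0.5 * 0.2^2 = 0.02
                   diff - 0.5 * beta)         # 2.0 - 0.5  = 1.5
# loss == tensor([0.0200, 1.5000])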
@LOSSES.register_module()
class LinesLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5):
        """Line regression loss based on smooth L1.

        Args:
            reduction (str, optional): The method to reduce the loss.
                Options are "none", "mean" and "sum".
            loss_weight (float, optional): The weight of loss.
        """
        super(LinesLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.beta = beta

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
                shape: [bs, ...]
            target (torch.Tensor): The learning target of the prediction.
                shape: [bs, ...]
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None. Useful when not all
                predictions are valid.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = smooth_l1_loss(pred, target, weight, reduction=reduction,
                              avg_factor=avg_factor, beta=self.beta)
        return loss * self.loss_weight
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def bce(pred, label, class_weight=None):
    """
    pred: B, nquery, npts
    label: B, nquery, npts
    """
    if label.numel() == 0:
        return pred.sum() * 0
    assert pred.size() == label.size()
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    return loss


@LOSSES.register_module()
class MasksLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MasksLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            xxx
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = bce(pred, target, weight, reduction=reduction,
                   avg_factor=avg_factor)
        return loss * self.loss_weight
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def ce(pred, label, class_weight=None):
    """
    pred: B*nquery, npts
    label: B*nquery,
    """
    if label.numel() == 0:
        return pred.sum() * 0
    loss = F.cross_entropy(
        pred, label, weight=class_weight, reduction='none')
    return loss


@LOSSES.register_module()
class LenLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(LenLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            xxx
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = ce(pred, target, weight, reduction=reduction,
                  avg_factor=avg_factor)
        return loss * self.loss_weight
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/mapers/__init__.py → autonomous_driving/Online-HD-Map-Construction/src/models/mapers/__init__.py
View file @ f3b13cad
from .vectormapnet import VectorMapNet
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/mapers/base_mapper.py → autonomous_driving/Online-HD-Map-Construction/src/models/mapers/base_mapper.py
View file @ f3b13cad
from abc import ABCMeta, abstractmethod

import torch.nn as nn
from mmcv.runner import auto_fp16
from mmcv.utils import print_log
from mmdet.utils import get_root_logger
from mmdet3d.models.builder import DETECTORS

MAPPERS = DETECTORS


class BaseMapper(nn.Module, metaclass=ABCMeta):
    """Base class for mappers."""

    def __init__(self):
        super(BaseMapper, self).__init__()
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the detector has a neck"""
        return hasattr(self, 'neck') and self.neck is not None

    # TODO: these properties need to be carefully handled
    # for both single stage & two stage detectors
    @property
    def with_shared_head(self):
        """bool: whether the detector has a shared head in the RoI Head"""
        return hasattr(self, 'roi_head') and self.roi_head.with_shared_head

    @property
    def with_bbox(self):
        """bool: whether the detector has a bbox head"""
        return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
                or (hasattr(self, 'bbox_head')
                    and self.bbox_head is not None))

    @property
    def with_mask(self):
        """bool: whether the detector has a mask head"""
        return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
                or (hasattr(self, 'mask_head')
                    and self.mask_head is not None))

    # @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from images."""
        pass

    def forward_train(self, *args, **kwargs):
        pass

    # @abstractmethod
    def simple_test(self, img, img_metas, **kwargs):
        pass

    # @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation."""
        pass

    def init_weights(self, pretrained=None):
        """Initialize the weights in detector.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if pretrained is not None:
            logger = get_root_logger()
            print_log(f'load model from: {pretrained}', logger=logger)

    def forward_test(self, *args, **kwargs):
        # Test-time augmentation is not supported; always use simple_test.
        if True:
            self.simple_test()
        else:
            self.aug_test()
    # @auto_fp16(apply_to=('img', ))
    def forward(self, *args, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test`
        depending on whether ``return_loss`` is ``True``.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(*args, **kwargs)
        else:
            kwargs.pop('rescale')
            return self.forward_test(*args, **kwargs)
    def train_step(self, data_dict, optimizer):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an
        optimizer hook. Note that in some complicated cases or models, the
        whole process including back propagation and optimizer updating is
        also defined in this method, such as GAN.

        Args:
            data_dict (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
                ``num_samples``.

                - ``loss`` is a tensor for back propagation, which can be a \
                weighted sum of multiple losses.
                - ``log_vars`` contains all the variables to be sent to the
                logger.
                - ``num_samples`` indicates the batch size (when the model is \
                DDP, it means the batch size on each GPU), which is used for \
                averaging the logs.
        """
        loss, log_vars, num_samples = self(**data_dict)
        outputs = dict(
            loss=loss,
            log_vars=log_vars,
            num_samples=num_samples)

        return outputs

    def val_step(self, data, optimizer):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but is
        used during val epochs. Note that the evaluation after training
        epochs is not implemented with this method, but with an evaluation
        hook.
        """
        loss, log_vars, num_samples = self(**data)
        outputs = dict(
            loss=loss,
            log_vars=log_vars,
            num_samples=num_samples)

        return outputs

    def show_result(self, **kwargs):
        img = None
        return img
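# A hedged sketch of how a loop would consume the train_step contract above.
# ToyMapper is a made-up stand-in; in this repo the mmcv runner drives the
# step and an optimizer hook performs the backward pass.
import torch

class ToyMapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 1)

    def train_step(self, data_dict, optimizer):
        loss = self.fc(data_dict['x']).pow(2).mean()
        return dict(loss=loss, log_vars={'total': loss.item()},
                    num_samples=data_dict['x'].size(0))

model = ToyMapper()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
outputs = model.train_step({'x': torch.randn(8, 4)}, optimizer)
optimizer.zero_grad()
outputs['loss'].backward()   # normally done by mmcv's optimizer hook
optimizer.step()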
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/mapers/vectormapnet.py → autonomous_driving/Online-HD-Map-Construction/src/models/mapers/vectormapnet.py
View file @ f3b13cad
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torchvision.models.resnet import resnet18, resnet50

from mmdet3d.models.builder import (build_backbone, build_head, build_neck)

from .base_mapper import BaseMapper, MAPPERS


@MAPPERS.register_module()
class VectorMapNet(BaseMapper):

    def __init__(self,
                 backbone_cfg=dict(),
                 head_cfg=dict(
                     vert_net_cfg=dict(),
                     face_net_cfg=dict(),
                 ),
                 neck_input_channels=128,
                 neck_cfg=None,
                 with_auxiliary_head=False,
                 only_det=False,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 model_name=None,
                 **kwargs):
        super(VectorMapNet, self).__init__()

        # Attributes
        self.model_name = model_name
        self.last_epoch = None
        self.only_det = only_det

        self.backbone = build_backbone(backbone_cfg)

        if neck_cfg is not None:
            self.neck_neck = build_backbone(neck_cfg.backbone)
            self.neck_neck.conv1 = nn.Conv2d(
                neck_input_channels, 64, kernel_size=7, stride=2,
                padding=3, bias=False)
            self.neck_project = build_neck(neck_cfg.neck)
            self.neck = self.multiscale_neck
        else:
            trunk = resnet18(pretrained=False, zero_init_residual=True)
            self.neck = nn.Sequential(
                nn.Conv2d(neck_input_channels, 64, kernel_size=(7, 7),
                          stride=(2, 2), padding=(3, 3), bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1,
                             ceil_mode=False),
                trunk.layer1,
                nn.Conv2d(64, 128, kernel_size=1, bias=False),
            )

        # BEV
        if hasattr(self.backbone, 'bev_w'):
            self.bev_w = self.backbone.bev_w
            self.bev_h = self.backbone.bev_h

        self.head = build_head(head_cfg)

    def multiscale_neck(self, bev_embedding):
        multi_feat = self.neck_neck(bev_embedding)
        multi_feat = self.neck_project(multi_feat)
        return multi_feat
def
forward_train
(
self
,
img
,
polys
,
points
=
None
,
img_metas
=
None
,
**
kwargs
):
'''
Args:
img: torch.Tensor of shape [B, N, 3, H, W]
N: number of cams
vectors: list[list[Tuple(lines, length, label)]]
- lines: np.array of shape [num_points, 2].
- length: int
- label: int
len(vectors) = batch_size
len(vectors[_b]) = num of lines in sample _b
img_metas:
img_metas['lidar2img']: [B, N, 4, 4]
Out:
loss, log_vars, num_sample
'''
# prepare labels and images
batch
,
img
,
img_metas
,
valid_idx
,
points
=
self
.
batch_data
(
polys
,
img
,
img_metas
,
img
.
device
,
points
)
# corner cases use hard code to prevent code fail
if
self
.
last_epoch
is
None
:
self
.
last_epoch
=
[
batch
,
img
,
img_metas
,
valid_idx
,
points
]
if
len
(
valid_idx
)
==
0
:
batch
,
img
,
img_metas
,
valid_idx
,
points
=
self
.
last_epoch
else
:
del
self
.
last_epoch
self
.
last_epoch
=
[
batch
,
img
,
img_metas
,
valid_idx
,
points
]
# Backbone
_bev_feats
=
self
.
backbone
(
img
,
img_metas
=
img_metas
,
points
=
points
)
img_shape
=
\
[
_bev_feats
.
shape
[
2
:]
for
i
in
range
(
_bev_feats
.
shape
[
0
])]
# Neck
bev_feats
=
self
.
neck
(
_bev_feats
)
preds_dict
,
losses_dict
=
\
self
.
head
(
batch
,
context
=
{
'bev_embeddings'
:
bev_feats
,
'batch_input_shape'
:
_bev_feats
.
shape
[
2
:],
'img_shape'
:
img_shape
,
'raw_bev_embeddings'
:
_bev_feats
},
only_det
=
self
.
only_det
)
# format outputs
loss
=
0
for
name
,
var
in
losses_dict
.
items
():
loss
=
loss
+
var
# update the log
log_vars
=
{
k
:
v
.
item
()
for
k
,
v
in
losses_dict
.
items
()}
log_vars
.
update
({
'total'
:
loss
.
item
()})
num_sample
=
img
.
size
(
0
)
return
loss
,
log_vars
,
num_sample
@
torch
.
no_grad
()
def
forward_test
(
self
,
img
,
polys
=
None
,
points
=
None
,
img_metas
=
None
,
**
kwargs
):
'''
inference pipeline
'''
# prepare labels and images
token
=
[]
for
img_meta
in
img_metas
:
token
.
append
(
img_meta
[
'token'
])
_bev_feats
=
self
.
backbone
(
img
,
img_metas
,
points
=
points
)
img_shape
=
[
_bev_feats
.
shape
[
2
:]
for
i
in
range
(
_bev_feats
.
shape
[
0
])]
# Neck
bev_feats
=
self
.
neck
(
_bev_feats
)
context
=
{
'bev_embeddings'
:
bev_feats
,
'batch_input_shape'
:
_bev_feats
.
shape
[
2
:],
'img_shape'
:
img_shape
,
# XXX
'raw_bev_embeddings'
:
_bev_feats
}
preds_dict
=
self
.
head
(
batch
=
{},
context
=
context
,
condition_on_det
=
True
,
gt_condition
=
False
,
only_det
=
self
.
only_det
)
# Hard Code
if
preds_dict
is
None
:
return
[
None
]
results_list
=
self
.
head
.
post_process
(
preds_dict
,
token
,
only_det
=
self
.
only_det
)
return
results_list
def
batch_data
(
self
,
polys
,
imgs
,
img_metas
,
device
,
points
=
None
):
# filter none vector's case
valid_idx
=
[
i
for
i
in
range
(
len
(
polys
))
if
len
(
polys
[
i
])]
imgs
=
imgs
[
valid_idx
]
img_metas
=
[
img_metas
[
i
]
for
i
in
valid_idx
]
polys
=
[
polys
[
i
]
for
i
in
valid_idx
]
if
points
is
not
None
:
points
=
[
points
[
i
]
for
i
in
valid_idx
]
points
=
self
.
batch_points
(
points
)
if
len
(
valid_idx
)
==
0
:
return
None
,
None
,
None
,
valid_idx
,
None
batch
=
{}
batch
[
'det'
]
=
format_det
(
polys
,
device
)
batch
[
'gen'
]
=
format_gen
(
polys
,
device
)
return
batch
,
imgs
,
img_metas
,
valid_idx
,
points
def
batch_points
(
self
,
points
):
pad_points
=
pad_sequence
(
points
,
batch_first
=
True
)
points_mask
=
torch
.
zeros_like
(
pad_points
[:,:,
0
]).
bool
()
for
i
in
range
(
len
(
points
)):
valid_num
=
points
[
i
].
shape
[
0
]
points_mask
[
i
][:
valid_num
]
=
True
return
(
pad_points
,
points_mask
)
def
format_det
(
polys
,
device
):
batch
=
{
'class_label'
:[],
'batch_idx'
:[],
'bbox'
:
[],
}
for
batch_idx
,
poly
in
enumerate
(
polys
):
keypoint_label
=
torch
.
from_numpy
(
poly
[
'det_label'
]).
to
(
device
)
keypoint
=
torch
.
from_numpy
(
poly
[
'keypoint'
]).
to
(
device
)
batch
[
'class_label'
].
append
(
keypoint_label
)
batch
[
'bbox'
].
append
(
keypoint
)
return
batch
def
format_gen
(
polys
,
device
):
line_cls
=
[]
polylines
,
polyline_masks
,
polyline_weights
=
[],
[],
[]
bbox
,
line_cls
,
line_bs_idx
=
[],
[],
[]
for
batch_idx
,
poly
in
enumerate
(
polys
):
# convert to cuda tensor
for
k
in
poly
.
keys
():
if
isinstance
(
poly
[
k
],
np
.
ndarray
):
poly
[
k
]
=
torch
.
from_numpy
(
poly
[
k
]).
to
(
device
)
else
:
poly
[
k
]
=
[
torch
.
from_numpy
(
v
).
to
(
device
)
for
v
in
poly
[
k
]]
line_cls
+=
poly
[
'gen_label'
]
line_bs_idx
+=
[
batch_idx
]
*
len
(
poly
[
'gen_label'
])
# condition
bbox
+=
poly
[
'qkeypoint'
]
# out
polylines
+=
poly
[
'polylines'
]
polyline_masks
+=
poly
[
'polyline_masks'
]
polyline_weights
+=
poly
[
'polyline_weights'
]
batch
=
{}
batch
[
'lines_bs_idx'
]
=
torch
.
tensor
(
line_bs_idx
,
dtype
=
torch
.
long
,
device
=
device
)
batch
[
'lines_cls'
]
=
torch
.
tensor
(
line_cls
,
dtype
=
torch
.
long
,
device
=
device
)
batch
[
'bbox_flat'
]
=
torch
.
stack
(
bbox
,
0
)
# padding
batch
[
'polylines'
]
=
pad_sequence
(
polylines
,
batch_first
=
True
)
batch
[
'polyline_masks'
]
=
pad_sequence
(
polyline_masks
,
batch_first
=
True
)
batch
[
'polyline_weights'
]
=
pad_sequence
(
polyline_weights
,
batch_first
=
True
)
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torchvision.models.resnet import resnet18, resnet50

from mmdet3d.models.builder import (build_backbone, build_head, build_neck)

from .base_mapper import BaseMapper, MAPPERS


@MAPPERS.register_module()
class VectorMapNet(BaseMapper):

    def __init__(self,
                 backbone_cfg=dict(),
                 head_cfg=dict(
                     vert_net_cfg=dict(),
                     face_net_cfg=dict(),
                 ),
                 neck_input_channels=128,
                 neck_cfg=None,
                 with_auxiliary_head=False,
                 only_det=False,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 model_name=None,
                 **kwargs):
        super(VectorMapNet, self).__init__()

        # Attributes
        self.model_name = model_name
        self.last_epoch = None
        self.only_det = only_det

        self.backbone = build_backbone(backbone_cfg)

        if neck_cfg is not None:
            self.neck_neck = build_backbone(neck_cfg.backbone)
            self.neck_neck.conv1 = nn.Conv2d(
                neck_input_channels, 64, kernel_size=7, stride=2, padding=3,
                bias=False)
            self.neck_project = build_neck(neck_cfg.neck)
            self.neck = self.multiscale_neck
        else:
            trunk = resnet18(pretrained=False, zero_init_residual=True)
            self.neck = nn.Sequential(
                nn.Conv2d(neck_input_channels, 64, kernel_size=(7, 7),
                          stride=(2, 2), padding=(3, 3), bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1,
                             ceil_mode=False),
                trunk.layer1,
                nn.Conv2d(64, 128, kernel_size=1, bias=False),
            )

        # BEV
        if hasattr(self.backbone, 'bev_w'):
            self.bev_w = self.backbone.bev_w
            self.bev_h = self.backbone.bev_h

        self.head = build_head(head_cfg)

    def multiscale_neck(self, bev_embedding):
        multi_feat = self.neck_neck(bev_embedding)
        multi_feat = self.neck_project(multi_feat)
        return multi_feat

    def forward_train(self, img, polys, points=None, img_metas=None, **kwargs):
        '''
        Args:
            img: torch.Tensor of shape [B, N, 3, H, W]
                N: number of cams
            vectors: list[list[Tuple(lines, length, label)]]
                - lines: np.array of shape [num_points, 2].
                - length: int
                - label: int
                len(vectors) = batch_size
                len(vectors[_b]) = num of lines in sample _b
            img_metas:
                img_metas['lidar2img']: [B, N, 4, 4]
        Out:
            loss, log_vars, num_sample
        '''
        # prepare labels and images
        batch, img, img_metas, valid_idx, points = \
            self.batch_data(polys, img, img_metas, img.device, points)

        # corner cases use hard code to prevent code fail
        if self.last_epoch is None:
            self.last_epoch = [batch, img, img_metas, valid_idx, points]
        if len(valid_idx) == 0:
            batch, img, img_metas, valid_idx, points = self.last_epoch
        else:
            del self.last_epoch
            self.last_epoch = [batch, img, img_metas, valid_idx, points]

        # Backbone
        _bev_feats = self.backbone(img, img_metas=img_metas, points=points)
        img_shape = \
            [_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]

        # Neck
        bev_feats = self.neck(_bev_feats)

        preds_dict, losses_dict = \
            self.head(batch,
                      context={
                          'bev_embeddings': bev_feats,
                          'batch_input_shape': _bev_feats.shape[2:],
                          'img_shape': img_shape,
                          'raw_bev_embeddings': _bev_feats
                      },
                      only_det=self.only_det)

        # format outputs
        loss = 0
        for name, var in losses_dict.items():
            loss = loss + var

        # update the log
        log_vars = {k: v.item() for k, v in losses_dict.items()}
        log_vars.update({'total': loss.item()})

        num_sample = img.size(0)

        return loss, log_vars, num_sample

    @torch.no_grad()
    def forward_test(self, img, polys=None, points=None, img_metas=None,
                     **kwargs):
        '''
        inference pipeline
        '''
        # prepare labels and images
        token = []
        for img_meta in img_metas:
            token.append(img_meta['token'])

        _bev_feats = self.backbone(img, img_metas, points=points)
        img_shape = [_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]

        # Neck
        bev_feats = self.neck(_bev_feats)

        context = {
            'bev_embeddings': bev_feats,
            'batch_input_shape': _bev_feats.shape[2:],
            'img_shape': img_shape,
            # XXX
            'raw_bev_embeddings': _bev_feats
        }

        preds_dict = self.head(batch={},
                               context=context,
                               condition_on_det=True,
                               gt_condition=False,
                               only_det=self.only_det)

        # Hard Code
        if preds_dict is None:
            return [None]

        results_list = self.head.post_process(
            preds_dict, token, only_det=self.only_det)

        return results_list

    def batch_data(self, polys, imgs, img_metas, device, points=None):
        # filter none vector's case
        valid_idx = [i for i in range(len(polys)) if len(polys[i])]
        imgs = imgs[valid_idx]
        img_metas = [img_metas[i] for i in valid_idx]
        polys = [polys[i] for i in valid_idx]
        if points is not None:
            points = [points[i] for i in valid_idx]
            points = self.batch_points(points)

        if len(valid_idx) == 0:
            return None, None, None, valid_idx, None

        batch = {}
        batch['det'] = format_det(polys, device)
        batch['gen'] = format_gen(polys, device)

        return batch, imgs, img_metas, valid_idx, points

    def batch_points(self, points):
        pad_points = pad_sequence(points, batch_first=True)
        points_mask = torch.zeros_like(pad_points[:, :, 0]).bool()
        for i in range(len(points)):
            valid_num = points[i].shape[0]
            points_mask[i][:valid_num] = True
        return (pad_points, points_mask)
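batch_points pads ragged per-sample point sets into one tensor and records which rows are real in a boolean mask; a minimal, self-contained sketch of that pattern with synthetic tensors:

import torch
from torch.nn.utils.rnn import pad_sequence

# two "point clouds" with different lengths
points = [torch.randn(3, 4), torch.randn(5, 4)]
pad_points = pad_sequence(points, batch_first=True)        # shape [2, 5, 4]
points_mask = torch.zeros_like(pad_points[:, :, 0]).bool()
for i, p in enumerate(points):
    points_mask[i][:p.shape[0]] = True                     # True = real point
print(pad_points.shape, points_mask.sum(dim=1))            # torch.Size([2, 5, 4]) tensor([3, 5])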
def format_det(polys, device):
    batch = {
        'class_label': [],
        'batch_idx': [],
        'bbox': [],
    }
    for batch_idx, poly in enumerate(polys):
        keypoint_label = torch.from_numpy(poly['det_label']).to(device)
        keypoint = torch.from_numpy(poly['keypoint']).to(device)

        batch['class_label'].append(keypoint_label)
        batch['bbox'].append(keypoint)

    return batch


def format_gen(polys, device):
    line_cls = []
    polylines, polyline_masks, polyline_weights = [], [], []
    bbox, line_cls, line_bs_idx = [], [], []

    for batch_idx, poly in enumerate(polys):
        # convert to cuda tensor
        for k in poly.keys():
            if isinstance(poly[k], np.ndarray):
                poly[k] = torch.from_numpy(poly[k]).to(device)
            else:
                poly[k] = [torch.from_numpy(v).to(device) for v in poly[k]]

        line_cls += poly['gen_label']
        line_bs_idx += [batch_idx] * len(poly['gen_label'])

        # condition
        bbox += poly['qkeypoint']

        # out
        polylines += poly['polylines']
        polyline_masks += poly['polyline_masks']
        polyline_weights += poly['polyline_weights']

    batch = {}
    batch['lines_bs_idx'] = torch.tensor(line_bs_idx, dtype=torch.long,
                                         device=device)
    batch['lines_cls'] = torch.tensor(line_cls, dtype=torch.long,
                                      device=device)
    batch['bbox_flat'] = torch.stack(bbox, 0)

    # padding
    batch['polylines'] = pad_sequence(polylines, batch_first=True)
    batch['polyline_masks'] = pad_sequence(polyline_masks, batch_first=True)
    batch['polyline_weights'] = pad_sequence(polyline_weights,
                                             batch_first=True)

    return batch
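For orientation, a hypothetical call to format_det with synthetic numpy inputs; the key names follow the function above, but the shapes are illustrative only, not the repo's real annotation format:

import numpy as np
import torch

polys = [{
    'det_label': np.array([0, 2]),                         # class id per map element
    'keypoint': np.random.rand(2, 2).astype(np.float32),   # [num_elem, 2]
}]
det_batch = format_det(polys, device=torch.device('cpu'))
print(det_batch['class_label'][0])         # tensor([0, 2])
print(det_batch['bbox'][0].shape)          # torch.Size([2, 2])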
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/transformer_utils/__init__.py → autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/__init__.py
from .deformable_transformer import (DeformableDetrTransformer_,
                                     DeformableDetrTransformerDecoder_)
from .base_transformer import PlaceHolderEncoder
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/transformer_utils/base_transformer.py → autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/base_transformer.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION, TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (MultiScaleDeformableAttention,
                                         TransformerLayerSequence,
                                         build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class PlaceHolderEncoder(nn.Module):

    def __init__(self, *args, embed_dims=None, **kwargs):
        super(PlaceHolderEncoder, self).__init__()
        self.embed_dims = embed_dims

    def forward(self, *args, query=None, **kwargs):
        return query
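Because PlaceHolderEncoder is registered in mmcv's TRANSFORMER_LAYER_SEQUENCE registry, it can be built from a config dict like any other layer sequence; a minimal sketch (the config values here are hypothetical):

import torch
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence

# Build the no-op encoder from a config, as mmcv-based models usually do.
encoder = build_transformer_layer_sequence(
    dict(type='PlaceHolderEncoder', embed_dims=256))
query = torch.randn(100, 2, 256)        # (num_query, bs, embed_dims)
assert encoder(query=query) is query    # features pass through untouched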
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/transformer_utils/deformable_transformer.py → autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/deformable_transformer.py
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings

import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer, xavier_init
from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
                                         TransformerLayerSequence,
                                         build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from torch.nn.init import normal_

from mmdet.models.utils.builder import TRANSFORMER
from mmdet.models.utils.transformer import Transformer

try:
    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
    warnings.warn(
        '`MultiScaleDeformableAttention` in MMCV has been moved to '
        '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
    from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention

from .fp16_dattn import MultiScaleDeformableAttentionFp16


def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid.

    Args:
        x (Tensor): The tensor to invert.
        eps (float): Epsilon to avoid numerical overflow. Defaults to 1e-5.

    Returns:
        Tensor: The elementwise inverse sigmoid of x, with the same shape
            as the input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
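A quick numeric check of inverse_sigmoid (inputs chosen arbitrarily): composing it with sigmoid recovers the input up to the eps clamping.

import torch

x = torch.tensor([0.1, 0.5, 0.9])
y = inverse_sigmoid(x)                                    # logit(x) = log(x / (1 - x))
print(torch.allclose(torch.sigmoid(y), x, atol=1e-6))     # True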
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DeformableDetrTransformerDecoder_(TransformerLayerSequence):
    """Implements the decoder in DETR transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        coder_norm_cfg (dict): Config of last normalization layer. Default:
            `LN`.
    """

    def __init__(self,
                 *args,
                 return_intermediate=False,
                 coord_dim=2,
                 kp_coord_dim=2,
                 **kwargs):
        super(DeformableDetrTransformerDecoder_, self).__init__(
            *args, **kwargs)
        self.return_intermediate = return_intermediate
        self.coord_dim = coord_dim
        self.kp_coord_dim = kp_coord_dim

    def forward(self,
                query,
                *args,
                reference_points=None,
                valid_ratios=None,
                reg_branches=None,
                **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.
            reference_points (Tensor): The reference points of offset,
                has shape (bs, num_query, 4) when as_two_stage,
                otherwise has shape (bs, num_query, 2).
            valid_ratios (Tensor): The ratios of valid points on the
                feature map, has shape (bs, num_levels, 2).
            reg_branches (obj:`nn.ModuleList`): Used for refining the
                regression results. Only would be passed when
                with_box_refine is True, otherwise would be passed `None`.

        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise it has shape
                [num_layers, num_query, bs, embed_dims].
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            reference_points_input = \
                reference_points[:, :, None, :self.kp_coord_dim] * \
                valid_ratios[:, None, :, :self.kp_coord_dim]
            # if reference_points.shape[-1] == 3 and self.kp_coord_dim==2:
            output = layer(
                output,
                *args,
                reference_points=reference_points_input[
                    ..., :self.kp_coord_dim],
                **kwargs)
            output = output.permute(1, 0, 2)

            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                new_reference_points = tmp
                new_reference_points[..., :self.kp_coord_dim] = \
                    tmp[..., :self.kp_coord_dim] + \
                    inverse_sigmoid(reference_points)
                new_reference_points = new_reference_points.sigmoid()
                if reference_points.shape[-1] == 3 and self.kp_coord_dim == 2:
                    reference_points[..., -1] = \
                        tmp[..., -1].sigmoid().detach()
                reference_points[..., :self.coord_dim] = \
                    new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points
@TRANSFORMER.register_module()
class DeformableDetrTransformer_(Transformer):
    """Implements the DeformableDETR transformer.

    Args:
        as_two_stage (bool): Generate query from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        two_stage_num_proposals (int): Number of proposals when set
            `as_two_stage` as True. Default: 300.
    """

    def __init__(self,
                 as_two_stage=False,
                 num_feature_levels=1,
                 two_stage_num_proposals=300,
                 coord_dim=2,
                 **kwargs):
        super(DeformableDetrTransformer_, self).__init__(**kwargs)
        self.as_two_stage = as_two_stage
        self.num_feature_levels = num_feature_levels
        self.two_stage_num_proposals = two_stage_num_proposals
        self.embed_dims = self.encoder.embed_dims
        self.coord_dim = coord_dim
        self.init_layers()

    def init_layers(self):
        """Initialize layers of the DeformableDetrTransformer."""
        self.level_embeds = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))

        if self.as_two_stage:
            self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
            self.enc_output_norm = nn.LayerNorm(self.embed_dims)
            self.pos_trans = nn.Linear(self.embed_dims * 2,
                                       self.embed_dims * 2)
            self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
        else:
            self.reference_points_embed = nn.Linear(self.embed_dims,
                                                    self.coord_dim)

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
            elif isinstance(m, MultiScaleDeformableAttentionFp16):
                m.init_weights()
        if not self.as_two_stage:
            xavier_init(self.reference_points_embed,
                        distribution='uniform',
                        bias=0.)
        normal_(self.level_embeds)

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Get the reference points used in decoder.

        Args:
            spatial_shapes (Tensor): The shape of all feature maps,
                has shape (num_level, 2).
            valid_ratios (Tensor): The ratios of valid points on the
                feature map, has shape (bs, num_levels, 2).
            device (obj:`device`): The device where reference_points
                should be.

        Returns:
            Tensor: reference points used in decoder, has shape
                (bs, num_keys, num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            # TODO check this 0.5
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H - 0.5, H, dtype=torch.float32,
                               device=device),
                torch.linspace(0.5, W - 0.5, W, dtype=torch.float32,
                               device=device))
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = \
            reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
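Because get_reference_points is a @staticmethod, it can be exercised without constructing the transformer; a small check with one 2x3 feature level and fully valid masks (synthetic inputs, illustrative only). Each reference point is a pixel center normalized to [0, 1]:

import torch

spatial_shapes = torch.tensor([[2, 3]])    # one level, H=2, W=3
valid_ratios = torch.ones(1, 1, 2)         # bs=1, everything valid
ref = DeformableDetrTransformer_.get_reference_points(
    spatial_shapes, valid_ratios, device='cpu')
print(ref.shape)         # torch.Size([1, 6, 1, 2]) -> (bs, H*W, num_levels, 2)
print(ref[0, :, 0, 0])   # x centers: 1/6, 1/2, 5/6, repeated per row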
    def get_valid_ratio(self, mask):
        """Get the valid ratios of feature maps of all levels."""
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def get_proposal_pos_embed(self,
                               proposals,
                               num_pos_feats=128,
                               temperature=10000):
        """Get the position embedding of proposal."""
        scale = 2 * math.pi
        dim_t = torch.arange(
            num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, 64, 2
        pos = torch.stack((pos[:, :, :, 0::2].sin(),
                           pos[:, :, :, 1::2].cos()),
                          dim=4).flatten(2)
        return pos
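The proposal embedding follows the standard DETR sinusoidal recipe: each of the 4 box coordinates is squashed to [0, 2*pi] and expanded into num_pos_feats sine/cosine channels. A standalone sketch of the same computation with synthetic inputs (shapes illustrative):

import math
import torch

proposals = torch.rand(2, 5, 4)                        # (N, L, 4) unnormalized
num_pos_feats, temperature = 128, 10000
dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
pos = proposals.sigmoid() * 2 * math.pi                # map coords to [0, 2*pi]
pos = pos[:, :, :, None] / dim_t                       # (N, L, 4, 128)
pos = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()),
                  dim=4).flatten(2)
print(pos.shape)                                       # torch.Size([2, 5, 512])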
    def forward(self,
                mlvl_feats,
                mlvl_masks,
                query_embed,
                mlvl_pos_embeds,
                reg_branches=None,
                cls_branches=None,
                **kwargs):
        """Forward function for `Transformer`.

        Args:
            mlvl_feats (list(Tensor)): Input queries from different level.
                Each element has shape [bs, embed_dims, h, w].
            mlvl_masks (list(Tensor)): The key_padding_mask from different
                level used for encoder and decoder, each element has shape
                [bs, h, w].
            query_embed (Tensor): The query embedding for decoder,
                with shape [num_query, c].
            mlvl_pos_embeds (list(Tensor)): The positional encoding of feats
                from different level, has the shape [bs, embed_dims, h, w].
            reg_branches (obj:`nn.ModuleList`): Regression heads for feature
                maps from each decoder layer. Only would be passed when
                `with_box_refine` is True. Default to None.
            cls_branches (obj:`nn.ModuleList`): Classification heads for
                feature maps from each decoder layer. Only would be passed
                when `as_two_stage` is True. Default to None.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - inter_states: Outputs from decoder. If
                  return_intermediate_dec is True output has shape
                  (num_dec_layers, bs, num_query, embed_dims), else has
                  shape (1, bs, num_query, embed_dims).
                - init_reference_out: The initial value of reference
                  points, has shape (bs, num_queries, 4).
                - inter_references_out: The internal value of reference
                  points in decoder, has shape
                  (num_dec_layers, bs, num_query, embed_dims).
                - enc_outputs_class: The classification score of proposals
                  generated from encoder's feature maps, has shape
                  (batch, h*w, num_classes). Only would be returned when
                  `as_two_stage` is True, otherwise None.
                - enc_outputs_coord_unact: The regression results generated
                  from encoder's feature maps, has shape (batch, h*w, 4).
                  Only would be returned when `as_two_stage` is True,
                  otherwise None.
        """
        assert self.as_two_stage or query_embed is not None

        feat_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (feat, mask, pos_embed) in enumerate(
                zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            feat = feat.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            feat_flatten.append(feat)
            mask_flatten.append(mask)
        feat_flatten = torch.cat(feat_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=feat_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack(
            [self.get_valid_ratio(m) for m in mlvl_masks], 1)

        # reference_points = \
        #     self.get_reference_points(spatial_shapes,
        #                               valid_ratios,
        #                               device=feat.device)

        feat_flatten = feat_flatten.permute(1, 0, 2)  # (H*W, bs, embed_dims)
        # lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
        #     1, 0, 2)  # (H*W, bs, embed_dims)
        # memory = self.encoder(
        #     query=feat_flatten,
        #     key=None,
        #     value=None,
        #     query_pos=lvl_pos_embed_flatten,
        #     query_key_padding_mask=mask_flatten,
        #     spatial_shapes=spatial_shapes,
        #     reference_points=reference_points,
        #     level_start_index=level_start_index,
        #     valid_ratios=valid_ratios,
        #     **kwargs)
        memory = feat_flatten.permute(1, 0, 2)

        bs, _, c = memory.shape
        query_pos, query = torch.split(query_embed, c, dim=-1)
        reference_points = self.reference_points_embed(query_pos).sigmoid()
        init_reference_out = reference_points

        # decoder
        query = query.permute(1, 0, 2)
        memory = memory.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=memory,
            query_pos=query_pos,
            key_padding_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=reg_branches,
            **kwargs)

        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out
autonomous_driving/Online-HD-Map-Construction-CVPR2023/src/models/transformer_utils/fp16_dattn.py → autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/fp16_dattn.py
# Note: the original file began with a spurious `from turtle import forward`
# (an auto-import artifact); it is unused and has been dropped.
import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import Function, once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

from mmcv import deprecated_api_warning
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.transformer import build_attention
from mmcv.runner import force_fp32, auto_fp16
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import ext_loader

try:
    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
    warnings.warn(
        '`MultiScaleDeformableAttention` in MMCV has been moved to '
        '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
    from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])


@ATTENTION.register_module()
class MultiScaleDeformableAttentionFp16(BaseModule):

    def __init__(self, attn_cfg=None, init_cfg=None, **kwarg):
        super(MultiScaleDeformableAttentionFp16, self).__init__(init_cfg)
        # import ipdb; ipdb.set_trace()
        self.deformable_attention = build_attention(attn_cfg)
        self.deformable_attention.init_weights()
        self.fp16_enabled = False

    @force_fp32(apply_to=('query', 'key', 'value', 'query_pos',
                          'reference_points', 'identity'))
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        # import ipdb; ipdb.set_trace()
        return self.deformable_attention(
            query,
            key=key,
            value=value,
            identity=identity,
            query_pos=query_pos,
            key_padding_mask=key_padding_mask,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            **kwargs)
class MultiScaleDeformableAttnFunctionFp32(Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        """GPU version of multi-scale deformable attention.

        Args:
            value (Tensor): The value has shape
                (bs, num_keys, num_heads, embed_dims//num_heads)
            value_spatial_shapes (Tensor): Spatial shape of
                each feature map, has shape (num_levels, 2),
                last dimension 2 represent (h, w)
            sampling_locations (Tensor): The location of sampling points,
                has shape
                (bs, num_queries, num_heads, num_levels, num_points, 2),
                the last dimension 2 represent (x, y).
            attention_weights (Tensor): The weight of sampling points used
                when calculate the attention, has shape
                (bs, num_queries, num_heads, num_levels, num_points).
            im2col_step (Tensor): The step used in image to column.

        Returns:
            Tensor: has shape (bs, num_queries, embed_dims)
        """
        ctx.im2col_step = im2col_step
        output = ext_module.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step=ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        """GPU version of backward function.

        Args:
            grad_output (Tensor): Gradient of output tensor of forward.

        Returns:
            Tuple[Tensor]: Gradient of input tensors in forward.
        """
        value, value_spatial_shapes, value_level_start_index, \
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value = torch.zeros_like(value)
        grad_sampling_loc = torch.zeros_like(sampling_locations)
        grad_attn_weight = torch.zeros_like(attention_weights)

        ext_module.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output.contiguous(),
            grad_value,
            grad_sampling_loc,
            grad_attn_weight,
            im2col_step=ctx.im2col_step)

        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None
def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
                                        sampling_locations,
                                        attention_weights):
    """CPU version of multi-scale deformable attention.

    Args:
        value (Tensor): The value has shape
            (bs, num_keys, num_heads, embed_dims//num_heads)
        value_spatial_shapes (Tensor): Spatial shape of
            each feature map, has shape (num_levels, 2),
            last dimension 2 represent (h, w)
        sampling_locations (Tensor): The location of sampling points,
            has shape
            (bs, num_queries, num_heads, num_levels, num_points, 2),
            the last dimension 2 represent (x, y).
        attention_weights (Tensor): The weight of sampling points used
            when calculate the attention, has shape
            (bs, num_queries, num_heads, num_levels, num_points).

    Returns:
        Tensor: has shape (bs, num_queries, embed_dims)
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = \
        sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
                             dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # bs, H_*W_, num_heads, embed_dims ->
        # bs, H_*W_, num_heads*embed_dims ->
        # bs, num_heads*embed_dims, H_*W_ ->
        # bs*num_heads, embed_dims, H_, W_
        value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
            bs * num_heads, embed_dims, H_, W_)
        # bs, num_queries, num_heads, num_points, 2 ->
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[:, :, :,
                                          level].transpose(1, 2).flatten(0, 1)
        # bs*num_heads, embed_dims, num_queries, num_points
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (bs, num_queries, num_heads, num_levels, num_points) ->
    # (bs, num_heads, num_queries, num_levels, num_points) ->
    # (bs, num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
              attention_weights).sum(-1).view(bs, num_heads * embed_dims,
                                              num_queries)
    return output.transpose(1, 2).contiguous()
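For a quick sanity check of the pure-PyTorch fallback above, small synthetic tensors suffice (values are random; only the shapes matter here):

import torch

bs, num_heads, head_dims = 2, 4, 8
spatial_shapes = torch.tensor([[6, 8], [3, 4]])      # two feature levels
num_keys = int(spatial_shapes.prod(1).sum())         # 6*8 + 3*4 = 60
num_queries, num_levels, num_points = 5, 2, 4

value = torch.randn(bs, num_keys, num_heads, head_dims)
locs = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
out = multi_scale_deformable_attn_pytorch(value, spatial_shapes, locs, weights)
print(out.shape)   # torch.Size([2, 5, 32]) -> (bs, num_queries, num_heads*head_dims)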
@ATTENTION.register_module()
class MultiScaleDeformableAttentionFP32(BaseModule):
    """An attention module used in Deformable-Detr.

    `Deformable DETR: Deformable Transformers for End-to-End Object
    Detection. <https://arxiv.org/pdf/2010.04159.pdf>`_.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
        num_levels (int): The number of feature map used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 4.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=False,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.sampling_offsets = nn.Linear(
            embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(
            embed_dims, num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()

    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
                         self.num_heads, 1, 1,
                         2).repeat(1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1

        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True

    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiScaleDeformableAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.

        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`. Default
                None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements are in range [0, 1], top-left (0, 0),
                bottom-right (1, 1), including padding area; or
                (N, Length_{query}, num_levels, 4), where the two
                additional dimensions (w, h) form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        if value is None:
            value = query

        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)

        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
        attention_weights = attention_weights.softmax(-1)

        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available():
            output = MultiScaleDeformableAttnFunctionFp32.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            # Bug fix: the pure-PyTorch fallback takes neither
            # level_start_index nor im2col_step; the original listing passed
            # both, which would raise a TypeError.
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)

        output = self.output_proj(output)

        if not self.batch_first:
            # (num_query, bs, embed_dims)
            output = output.permute(1, 0, 2)

        # The trailing return was cut off in this listing; mmcv's reference
        # implementation of this module ends by applying dropout and the
        # residual identity, which we restore here.
        return self.dropout(output) + identity
autonomous_driving/Online-HD-Map-Construction-CVPR2023/tools/dist_test.sh → autonomous_driving/Online-HD-Map-Construction/tools/dist_test.sh
#!/usr/bin/env bash

CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-29500}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
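Everything after the third positional argument is forwarded to tools/test.py via ${@:4}, so a typical 8-GPU evaluation looks like `bash tools/dist_test.sh path/to/config.py path/to/checkpoint.pth 8 --split val --eval` (paths illustrative); prefixing e.g. `PORT=29501` avoids master-port clashes when several jobs share a machine.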
autonomous_driving/Online-HD-Map-Construction-CVPR2023/tools/dist_train.sh → autonomous_driving/Online-HD-Map-Construction/tools/dist_train.sh
#!/usr/bin/env bash

CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
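Likewise, arguments after the GPU count are forwarded to tools/train.py via ${@:3}, e.g. `bash tools/dist_train.sh path/to/config.py 8 --work-dir work_dirs/exp1` (paths illustrative).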
autonomous_driving/Online-HD-Map-Construction-CVPR2023/tools/evaluate_submission.py → autonomous_driving/Online-HD-Map-Construction/tools/evaluate_submission.py
import sys
import os

sys.path.append(os.path.abspath('.'))

from src.datasets.evaluation.vector_eval import VectorEvaluate

import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='Evaluate a submission file')
    parser.add_argument(
        'submission',
        help='submission file in pickle or json format to be evaluated')
    parser.add_argument('gt', help='gt annotation file')
    args = parser.parse_args()
    return args


def main(args):
    evaluator = VectorEvaluate(args.gt, n_workers=0)
    results = evaluator.evaluate(args.submission)
    print(results)


if __name__ == '__main__':
    args = parse_args()
    main(args)
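The same evaluation can also be driven from Python without the argparse wrapper; a minimal sketch, assuming it is run from the repository root and with placeholder file paths:

# Programmatic equivalent of the CLI above; both paths are placeholders.
import os
import sys

sys.path.append(os.path.abspath('.'))  # make the 'src' package importable

from src.datasets.evaluation.vector_eval import VectorEvaluate

evaluator = VectorEvaluate('path/to/gt_annotations.json', n_workers=0)
results = evaluator.evaluate('path/to/submission.pkl')
print(results)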
autonomous_driving/Online-HD-Map-Construction-CVPR2023/tools/mmdet_test.py → autonomous_driving/Online-HD-Map-Construction/tools/mmdet_test.py
import os.path as osp
import pickle
import shutil
import tempfile
import time

import mmcv
import torch
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info

from mmdet.core import encode_mask_results


def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    show_score_thr=0.3):
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        batch_size = len(result)
        if show or out_dir:
            if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
                img_tensor = data['img'][0]
            else:
                img_tensor = data['img'][0].data[0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)

            for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                h, w, _ = img_meta['img_shape']
                img_show = img[:h, :w, :]

                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img_show, (ori_w, ori_h))

                if out_dir:
                    out_file = osp.join(out_dir, img_meta['ori_filename'])
                else:
                    out_file = None

                model.module.show_result(
                    img_show,
                    result[i],
                    show=show,
                    out_file=out_file,
                    score_thr=show_score_thr)

        # encode mask results
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)

        for _ in range(batch_size):
            prog_bar.update()
    return results


def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test model with multiple gpus.

    This method tests a model with multiple gpus and collects the results
    under two different modes: gpu and cpu. By setting 'gpu_collect=True'
    it encodes results to gpu tensors and uses gpu communication for result
    collection. In cpu mode it saves the results on different gpus to 'tmpdir'
    and collects them by the rank 0 worker.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.

    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        # encode mask results
        # if isinstance(result[0], tuple):
        #     result = [(bbox_results, encode_mask_results(mask_results))
        #               for bbox_results, mask_results in result]
        results.extend(result)

        if rank == 0:
            batch_size = len(result)
            for _ in range(batch_size * world_size):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results


def collect_results_cpu(result_part, size, tmpdir=None):
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_list.append(mmcv.load(part_file))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results


def collect_results_gpu(result_part, size):
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_list.append(
                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
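A note on the "sort the results" step in both collectors: with a non-shuffled distributed sampler, rank r processes samples r, r + world_size, r + 2*world_size, ..., so interleaving the per-rank result lists with zip(*part_list) restores dataset order. A toy illustration with made-up values for world_size = 2:

# rank 0 saw samples 0, 2, 4; rank 1 saw samples 1, 3, 5
part_list = [['s0', 's2', 's4'], ['s1', 's3', 's5']]
ordered_results = []
for res in zip(*part_list):
    ordered_results.extend(list(res))
print(ordered_results)  # ['s0', 's1', 's2', 's3', 's4', 's5'] -> original order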
autonomous_driving/Online-HD-Map-Construction-CVPR2023/tools/mmdet_train.py → autonomous_driving/Online-HD-Map-Construction/tools/mmdet_train.py
import random
import warnings

import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
                         Fp16OptimizerHook, OptimizerHook, build_optimizer,
                         build_runner)
from mmcv.utils import build_from_cfg

from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.utils import get_root_logger


def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiment')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiment')
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
        if val_samples_per_gpu > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
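For orientation, this is a minimal sketch of the config fields that train_detector() above reads; every value is an illustrative assumption, not the repository's actual training configuration:

# Hypothetical config skeleton covering the fields train_detector() consumes.
from mmcv import Config

cfg = Config(
    dict(
        log_level='INFO',
        work_dir='./work_dirs/example',     # where logs/checkpoints land
        seed=0,
        gpu_ids=range(1),
        workflow=[('train', 1)],
        resume_from=None,
        load_from=None,
        data=dict(samples_per_gpu=2, workers_per_gpu=2),
        optimizer=dict(type='AdamW', lr=2e-4, weight_decay=1e-2),
        optimizer_config=dict(grad_clip=dict(max_norm=35, norm_type=2)),
        lr_config=dict(policy='step', step=[16, 22]),
        checkpoint_config=dict(interval=1),
        log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]),
        # without a `runner` section, the function falls back to
        # EpochBasedRunner with cfg.total_epochs and emits a warning
        runner=dict(type='EpochBasedRunner', max_epochs=24),
    ))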
autonomous_driving/Online-HD-Map-Construction-CVPR2023/tools/test.py → autonomous_driving/Online-HD-Map-Construction/tools/test.py
import argparse
import mmcv
import os
import os.path as osp
import torch
import warnings
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_model
from mmdet_test import multi_gpu_test
from mmdet_train import set_random_seed
from mmdet.datasets import replace_ImageToTensor


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument(
        '--split', type=str, required=True, help='which split to test on')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn; this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval', action='store_true', help='whether to run evaluation.')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
    args = parse_args()

    if args.split not in ['val', 'test']:
        raise ValueError('Please choose "val" or "test" split for testing')
    if (args.eval and args.format_only) or \
            (not args.eval and not args.format_only):
        raise ValueError('Please specify exactly one operation (eval/format) '
                         'with the argument "--eval" or "--format-only"')
    if args.eval and args.split == 'test':
        raise ValueError('Cannot evaluate on test set')

    cfg = Config.fromfile(args.config)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # import modules from plugin/xx, registry will be updated
    import sys
    sys.path.append(os.path.abspath('.'))
    if hasattr(cfg, 'plugin'):
        if cfg.plugin:
            import importlib
            if hasattr(cfg, 'plugin_dir'):

                def import_path(plugin_dir):
                    _module_dir = os.path.dirname(plugin_dir)
                    _module_dir = _module_dir.split('/')
                    _module_path = _module_dir[0]
                    for m in _module_dir[1:]:
                        _module_path = _module_path + '.' + m
                    print(f'importing {_module_path}/')
                    plg_lib = importlib.import_module(_module_path)

                plugin_dirs = cfg.plugin_dir
                if not isinstance(plugin_dirs, list):
                    plugin_dirs = [plugin_dirs]
                for plugin_dir in plugin_dirs:
                    import_path(plugin_dir)
            else:
                # import dir is the dirpath for the config file
                _module_dir = os.path.dirname(args.config)
                _module_dir = _module_dir.split('/')
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + '.' + m
                print(f'importing {_module_path}/')
                plg_lib = importlib.import_module(_module_path)

    cfg_data_dict = cfg.data.get(args.split)

    cfg.model.pretrained = None
    # in case the test dataset is concatenated
    samples_per_gpu = 1
    cfg_data_dict.test_mode = True
    samples_per_gpu = cfg_data_dict.pop('samples_per_gpu', 1)
    if samples_per_gpu > 1:
        # Replace 'ImageToTensor' with 'DefaultFormatBundle'
        cfg_data_dict.pipeline = replace_ImageToTensor(cfg_data_dict.pipeline)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the dataloader
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    cfg_data_dict.work_dir = cfg.work_dir
    print('work_dir: ', cfg.work_dir)

    dataset = build_dataset(cfg_data_dict)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    rank, _ = get_dist_info()
    if rank == 0:
        if args.format_only:
            dataset.format_results(outputs, prefix=cfg.work_dir)
        elif args.eval:
            print('start evaluation!')
            print(dataset.evaluate(outputs))


if __name__ == '__main__':
    main()
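Given the argument checks in main(), the two supported invocations are `python tools/test.py <config> <checkpoint> --split val --eval` for local evaluation and `python tools/test.py <config> <checkpoint> --split test --format-only` to write a submission file under the resolved work_dir; `<config>` and `<checkpoint>` are placeholders.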
autonomous_driving/Online-HD-Map-Construction-CVPR2023/tools/train.py → autonomous_driving/Online-HD-Map-Construction/tools/train.py
from
__future__
import
division
import
argparse
import
copy
import
mmcv
import
os
import
time
import
torch
import
warnings
from
mmcv
import
Config
,
DictAction
from
mmcv.runner
import
get_dist_info
,
init_dist
from
os
import
path
as
osp
from
mmdet
import
__version__
as
mmdet_version
from
mmdet3d
import
__version__
as
mmdet3d_version
from
mmdet3d.apis
import
train_model
from
mmdet3d.datasets
import
build_dataset
from
mmdet3d.utils
import
collect_env
,
get_root_logger
from
mmseg
import
__version__
as
mmseg_version
# warper
from
mmdet_train
import
set_random_seed
# from builder import build_model
from
mmdet3d.models
import
build_model
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a detector'
)
parser
.
add_argument
(
'config'
,
help
=
'train config file path'
)
parser
.
add_argument
(
'--work-dir'
,
help
=
'the dir to save logs and models'
)
parser
.
add_argument
(
'--resume-from'
,
help
=
'the checkpoint file to resume from'
)
parser
.
add_argument
(
'--no-validate'
,
action
=
'store_true'
,
help
=
'whether not to evaluate the checkpoint during training'
)
group_gpus
=
parser
.
add_mutually_exclusive_group
()
group_gpus
.
add_argument
(
'--gpus'
,
type
=
int
,
help
=
'number of gpus to use '
'(only applicable to non-distributed training)'
)
group_gpus
.
add_argument
(
'--gpu-ids'
,
type
=
int
,
nargs
=
'+'
,
help
=
'ids of gpus to use '
'(only applicable to non-distributed training)'
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
0
,
help
=
'random seed'
)
parser
.
add_argument
(
'--deterministic'
,
action
=
'store_true'
,
help
=
'whether to set deterministic options for CUDNN backend.'
)
parser
.
add_argument
(
'--options'
,
nargs
=
'+'
,
action
=
DictAction
,
help
=
'override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file (deprecate), '
'change to --cfg-options instead.'
)
parser
.
add_argument
(
'--cfg-options'
,
nargs
=
'+'
,
action
=
DictAction
,
help
=
'override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.'
)
parser
.
add_argument
(
'--launcher'
,
choices
=
[
'none'
,
'pytorch'
,
'slurm'
,
'mpi'
],
default
=
'none'
,
help
=
'job launcher'
)
parser
.
add_argument
(
'--local_rank'
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
'--autoscale-lr'
,
action
=
'store_true'
,
help
=
'automatically scale lr with the number of gpus'
)
args
=
parser
.
parse_args
()
if
'LOCAL_RANK'
not
in
os
.
environ
:
os
.
environ
[
'LOCAL_RANK'
]
=
str
(
args
.
local_rank
)
if
args
.
options
and
args
.
cfg_options
:
raise
ValueError
(
'--options and --cfg-options cannot be both specified, '
'--options is deprecated in favor of --cfg-options'
)
if
args
.
options
:
warnings
.
warn
(
'--options is deprecated in favor of --cfg-options'
)
args
.
cfg_options
=
args
.
options
return
args
def
main
():
args
=
parse_args
()
cfg
=
Config
.
fromfile
(
args
.
config
)
if
args
.
cfg_options
is
not
None
:
cfg
.
merge_from_dict
(
args
.
cfg_options
)
# import modules from string list.
if
cfg
.
get
(
'custom_imports'
,
None
):
from
mmcv.utils
import
import_modules_from_strings
import_modules_from_strings
(
**
cfg
[
'custom_imports'
])
# set cudnn_benchmark
if
cfg
.
get
(
'cudnn_benchmark'
,
False
):
torch
.
backends
.
cudnn
.
benchmark
=
True
# import modules, registry will be updated
import
sys
sys
.
path
.
append
(
os
.
path
.
abspath
(
'.'
))
if
hasattr
(
cfg
,
'plugin'
):
if
cfg
.
plugin
:
import
importlib
if
hasattr
(
cfg
,
'plugin_dir'
):
def
import_path
(
plugin_dir
):
_module_dir
=
os
.
path
.
dirname
(
plugin_dir
)
_module_dir
=
_module_dir
.
split
(
'/'
)
_module_path
=
_module_dir
[
0
]
for
m
in
_module_dir
[
1
:]:
_module_path
=
_module_path
+
'.'
+
m
print
(
f
'importing
{
_module_path
}
/'
)
plg_lib
=
importlib
.
import_module
(
_module_path
)
plugin_dirs
=
cfg
.
plugin_dir
if
not
isinstance
(
plugin_dirs
,
list
):
plugin_dirs
=
[
plugin_dirs
,]
for
plugin_dir
in
plugin_dirs
:
import_path
(
plugin_dir
)
else
:
# import dir is the dirpath for the config file
_module_dir
=
os
.
path
.
dirname
(
args
.
config
)
_module_dir
=
_module_dir
.
split
(
'/'
)
_module_path
=
_module_dir
[
0
]
for
m
in
_module_dir
[
1
:]:
_module_path
=
_module_path
+
'.'
+
m
print
(
f
'importing
{
_module_path
}
/'
)
plg_lib
=
importlib
.
import_module
(
_module_path
)
# work_dir is determined in this priority: CLI > segment in file > filename
if
args
.
work_dir
is
not
None
:
# update configs according to CLI args if args.work_dir is not None
cfg
.
work_dir
=
args
.
work_dir
elif
cfg
.
get
(
'work_dir'
,
None
)
is
None
:
# use config filename as default work_dir if cfg.work_dir is None
cfg
.
work_dir
=
osp
.
join
(
'./work_dirs'
,
osp
.
splitext
(
osp
.
basename
(
args
.
config
))[
0
])
if
args
.
resume_from
is
not
None
:
cfg
.
resume_from
=
args
.
resume_from
if
args
.
gpu_ids
is
not
None
:
cfg
.
gpu_ids
=
args
.
gpu_ids
else
:
cfg
.
gpu_ids
=
range
(
1
)
if
args
.
gpus
is
None
else
range
(
args
.
gpus
)
if
args
.
autoscale_lr
:
# apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
cfg
.
optimizer
[
'lr'
]
=
cfg
.
optimizer
[
'lr'
]
*
len
(
cfg
.
gpu_ids
)
/
8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    # specify the logger name; if we still use 'mmdet', the output info will
    # be filtered out and won't be saved in the log_file
    # TODO: ugly workaround to judge whether we are training a det or seg model
    if cfg.model.type in ['EncoderDecoder3D']:
        logger_name = 'mmseg'
    else:
        logger_name = 'mmdet'
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.log_level, name=logger_name)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_model(
        cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
    model.init_weights()
    logger.info(f'Model:\n{model}')

    cfg.data.train.work_dir = cfg.work_dir
    cfg.data.val.work_dir = cfg.work_dir
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        # in case we use a dataset wrapper
        if 'dataset' in cfg.data.train:
            val_dataset.pipeline = cfg.data.train.dataset.pipeline
        else:
            val_dataset.pipeline = cfg.data.train.pipeline
        # set test_mode=False in the deep-copied config, which does not affect
        # the AP/AR calculation later
        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow  # noqa
        val_dataset.test_mode = False
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=mmdet_version,
            mmseg_version=mmseg_version,
            mmdet3d_version=mmdet3d_version,
            config=cfg.pretty_text,
            CLASSES=None,
            PALETTE=datasets[0].PALETTE  # for segmentors
            if hasattr(datasets[0], 'PALETTE') else None)
    # add an attribute for visualization convenience
    # model.CLASSES = datasets[0].CLASSES
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
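
A typical single-GPU invocation of this script, with an illustrative config path (only the positional config argument is required; the remaining flags are the ones defined in parse_args above):

    python tools/train.py path/to/config.py --work-dir work_dirs/my_exp --seed 0 --deterministic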
autonomous_driving/Online-HD-Map-Construction-CVPR2023/tools/visualization/renderer.py → autonomous_driving/Online-HD-Map-Construction/tools/visualization/renderer.py
import os.path as osp
import os

import numpy as np
import copy
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from shapely.geometry import LineString


def remove_nan_values(uv):
    is_u_valid = np.logical_not(np.isnan(uv[:, 0]))
    is_v_valid = np.logical_not(np.isnan(uv[:, 1]))
    is_uv_valid = np.logical_and(is_u_valid, is_v_valid)

    uv_valid = uv[is_uv_valid]
    return uv_valid

def points_ego2img(pts_ego, extrinsics, intrinsics):
    pts_ego_4d = np.concatenate([pts_ego, np.ones([len(pts_ego), 1])], axis=-1)
    pts_cam_4d = extrinsics @ pts_ego_4d.T

    uv = (intrinsics @ pts_cam_4d[:3, :]).T
    uv = remove_nan_values(uv)
    depth = uv[:, 2]
    uv = uv[:, :2] / uv[:, 2].reshape(-1, 1)

    return uv, depth
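
# Worked example of the projection above (illustrative camera parameters):
# with identity extrinsics and K = [[1000, 0, 800], [0, 1000, 450], [0, 0, 1]],
# an ego point (0, 0, 10) maps to pixel (800, 450) at depth 10, i.e. a point
# 10 m along the optical axis lands on the principal point.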

def interp_fixed_dist(line, sample_dist):
    '''Interpolate a line at a fixed interval.

    Args:
        line (LineString): line
        sample_dist (float): sample interval

    Returns:
        points (array): interpolated points, shape (N, 2)
    '''
    distances = list(np.arange(sample_dist, line.length, sample_dist))
    # make sure to sample at least two points when sample_dist > line.length
    distances = [0, ] + distances + [line.length, ]

    sampled_points = np.array([list(line.interpolate(distance).coords)
                               for distance in distances]).squeeze()

    return sampled_points
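
# Example: for a straight 1 m line with sample_dist=0.2, `distances` is
# approximately [0, 0.2, 0.4, 0.6, 0.8, 1.0], so six evenly spaced points are
# returned; prepending 0 and appending line.length guarantees at least two
# points even when sample_dist exceeds the line length.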

def draw_polyline_ego_on_img(polyline_ego, img_bgr, extrinsics, intrinsics,
                             color_bgr, thickness):
    # if 2-dimensional, assume z=0
    if polyline_ego.shape[1] == 2:
        zeros = np.zeros((polyline_ego.shape[0], 1))
        polyline_ego = np.concatenate([polyline_ego, zeros], axis=1)

    polyline_ego = interp_fixed_dist(line=LineString(polyline_ego),
                                     sample_dist=0.2)
    uv, depth = points_ego2img(polyline_ego, extrinsics, intrinsics)

    h, w, c = img_bgr.shape
    is_valid_x = np.logical_and(0 <= uv[:, 0], uv[:, 0] < w - 1)
    is_valid_y = np.logical_and(0 <= uv[:, 1], uv[:, 1] < h - 1)
    is_valid_z = depth > 0
    is_valid_points = np.logical_and.reduce([is_valid_x, is_valid_y, is_valid_z])

    if is_valid_points.sum() == 0:
        return

    # Split the projected polyline into maximal runs of in-frame points and
    # draw each run as its own polyline.
    tmp_list = []
    for i, valid in enumerate(is_valid_points):
        if valid:
            tmp_list.append(uv[i])
        else:
            if len(tmp_list) >= 2:
                tmp_vector = np.round(np.stack(tmp_list)).astype(np.int32)
                draw_visible_polyline_cv2(
                    copy.deepcopy(tmp_vector),
                    valid_pts_bool=np.ones((len(tmp_vector),), dtype=bool),
                    image=img_bgr,
                    color=color_bgr,
                    thickness_px=thickness,
                )
            tmp_list = []
    if len(tmp_list) >= 2:
        tmp_vector = np.round(np.stack(tmp_list)).astype(np.int32)
        draw_visible_polyline_cv2(
            copy.deepcopy(tmp_vector),
            valid_pts_bool=np.ones((len(tmp_vector),), dtype=bool),
            image=img_bgr,
            color=color_bgr,
            thickness_px=thickness,
        )

    # uv = np.round(uv[is_valid_points]).astype(np.int32)
    # draw_visible_polyline_cv2(
    #     copy.deepcopy(uv),
    #     valid_pts_bool=np.ones((len(uv), 1), dtype=bool),
    #     image=img_bgr,
    #     color=color_bgr,
    #     thickness_px=thickness,
    # )
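
# Example of the run-splitting above: if the validity mask is
# [True, True, False, True, True], two separate polylines are drawn, one
# through points 0-1 and one through points 3-4, so the rendered line is not
# connected across the out-of-frame point.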

def draw_visible_polyline_cv2(line, valid_pts_bool, image, color, thickness_px):
    """Draw a polyline onto an image using given line segments.

    Args:
        line: Array of shape (K, 2) representing the coordinates of the line.
        valid_pts_bool: Array of shape (K,) representing which polyline coordinates are
            valid for rendering. For example, if the coordinate is occluded, a user might
            specify that it is invalid. Line segments touching an invalid vertex will not
            be rendered.
        image: Array of shape (H, W, 3), representing a 3-channel BGR image.
        color: Tuple of shape (3,) with a BGR format color.
        thickness_px: thickness (in pixels) to use when rendering the polyline.
    """
    line = np.round(line).astype(int)  # type: ignore
    for i in range(len(line) - 1):
        if (not valid_pts_bool[i]) or (not valid_pts_bool[i + 1]):
            continue

        x1 = line[i][0]
        y1 = line[i][1]
        x2 = line[i + 1][0]
        y2 = line[i + 1][1]

        # Use anti-aliasing (AA) for curves
        image = cv2.line(image, pt1=(x1, y1), pt2=(x2, y2), color=color,
                         thickness=thickness_px, lineType=cv2.LINE_AA)

COLOR_MAPS_BGR = {
    # bgr colors
    'divider': (0, 0, 255),
    'boundary': (0, 255, 0),
    'ped_crossing': (255, 0, 0),
    'centerline': (51, 183, 255),
    'drivable_area': (171, 255, 255),
}

COLOR_MAPS_PLT = {
    'divider': 'r',
    'boundary': 'g',
    'ped_crossing': 'b',
    'centerline': 'orange',
    'drivable_area': 'y',
}

CAM_NAMES_AV2 = [
    'ring_front_center', 'ring_front_right', 'ring_front_left',
    'ring_rear_right', 'ring_rear_left', 'ring_side_right',
    'ring_side_left',
]

class Renderer(object):
    """Render map elements on image views.

    Args:
        roi_size (tuple): bev range
    """

    def __init__(self, roi_size):
        self.roi_size = roi_size

    def render_bev_from_vectors(self, vectors, out_dir):
        '''Plot vectorized map elements on BEV.

        Args:
            vectors (dict): dict of vectorized map elements.
            out_dir (str): output directory
        '''
        car_img = Image.open('resources/images/car.png')
        map_path = os.path.join(out_dir, 'map.jpg')

        plt.figure(figsize=(self.roi_size[0], self.roi_size[1]))
        plt.xlim(-self.roi_size[0] / 2 - 1, self.roi_size[0] / 2 + 1)
        plt.ylim(-self.roi_size[1] / 2 - 1, self.roi_size[1] / 2 + 1)
        plt.axis('off')
        plt.imshow(car_img, extent=[-1.5, 1.5, -1.2, 1.2])

        for cat, vector_list in vectors.items():
            color = COLOR_MAPS_PLT[cat]
            for vector in vector_list:
                pts = np.array(vector)[:, :2]
                x = np.array([pt[0] for pt in pts])
                y = np.array([pt[1] for pt in pts])
                # plt.quiver(x[:-1], y[:-1], x[1:] - x[:-1], y[1:] - y[:-1],
                #            angles='xy', color=color, scale_units='xy', scale=1)
                plt.plot(x, y, color=color, linewidth=5, marker='o',
                         linestyle='-', markersize=20)

        plt.savefig(map_path, bbox_inches='tight', dpi=40)
        plt.close()

    def render_camera_views_from_vectors(self, vectors, imgs, extrinsics,
                                         intrinsics, thickness, out_dir):
        '''Project vectorized map elements to camera views.

        Args:
            vectors (dict): dict of vectorized map elements.
            imgs (tensor): images in bgr color.
            extrinsics (array): ego2img extrinsics, shape (4, 4)
            intrinsics (array): intrinsics, shape (3, 3)
            thickness (int): thickness of lines to draw on images.
            out_dir (str): output directory
        '''
        for i in range(len(imgs)):
            img = imgs[i]
            extrinsic = extrinsics[i]
            intrinsic = intrinsics[i]
            img_bgr = copy.deepcopy(img)

            for cat, vector_list in vectors.items():
                color = COLOR_MAPS_BGR[cat]
                for vector in vector_list:
                    img_bgr = np.ascontiguousarray(img_bgr)
                    vector_array = np.array(vector)
                    if vector_array.shape[1] > 3:
                        vector_array = vector_array[:, :3]
                    draw_polyline_ego_on_img(vector_array, img_bgr, extrinsic,
                                             intrinsic, color, thickness)

            out_path = osp.join(out_dir, CAM_NAMES_AV2[i]) + '.jpg'
            cv2.imwrite(out_path, img_bgr)
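
A minimal driver for the Renderer above (a sketch; the vector values, ROI size, and output directory are illustrative, and render_bev_from_vectors additionally expects resources/images/car.png and an existing out_dir):

    import numpy as np

    # One hypothetical divider polyline, given as (x, y) points in ego coordinates.
    vectors = {'divider': [np.array([[0.0, -5.0], [10.0, -5.0], [20.0, -4.5]])]}

    renderer = Renderer(roi_size=(60, 30))  # 60 m x 30 m BEV window
    renderer.render_bev_from_vectors(vectors, out_dir='vis_out')  # writes vis_out/map.jpg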