Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dcnv3
Commits
f3b13cad
Commit
f3b13cad
authored
May 17, 2023
by
yeshenglong1
Browse files
Update README.md
parent
0797920d
Changes
102
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3917 additions
and
3917 deletions
+3917
-3917
autonomous_driving/Online-HD-Map-Construction/src/models/heads/dg_head.py
...ng/Online-HD-Map-Construction/src/models/heads/dg_head.py
+305
-305
autonomous_driving/Online-HD-Map-Construction/src/models/heads/map_element_detector.py
...Map-Construction/src/models/heads/map_element_detector.py
+627
-627
autonomous_driving/Online-HD-Map-Construction/src/models/heads/polyline_generator.py
...D-Map-Construction/src/models/heads/polyline_generator.py
+525
-525
autonomous_driving/Online-HD-Map-Construction/src/models/losses/__init__.py
.../Online-HD-Map-Construction/src/models/losses/__init__.py
+2
-2
autonomous_driving/Online-HD-Map-Construction/src/models/losses/detr_loss.py
...Online-HD-Map-Construction/src/models/losses/detr_loss.py
+169
-169
autonomous_driving/Online-HD-Map-Construction/src/models/mapers/__init__.py
.../Online-HD-Map-Construction/src/models/mapers/__init__.py
+1
-1
autonomous_driving/Online-HD-Map-Construction/src/models/mapers/base_mapper.py
...line-HD-Map-Construction/src/models/mapers/base_mapper.py
+148
-148
autonomous_driving/Online-HD-Map-Construction/src/models/mapers/vectormapnet.py
...ine-HD-Map-Construction/src/models/mapers/vectormapnet.py
+260
-260
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/__init__.py
...Map-Construction/src/models/transformer_utils/__init__.py
+1
-1
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/base_transformer.py
...truction/src/models/transformer_utils/base_transformer.py
+24
-24
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/deformable_transformer.py
...on/src/models/transformer_utils/deformable_transformer.py
+368
-368
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/fp16_dattn.py
...p-Construction/src/models/transformer_utils/fp16_dattn.py
+401
-401
autonomous_driving/Online-HD-Map-Construction/tools/dist_test.sh
...ous_driving/Online-HD-Map-Construction/tools/dist_test.sh
+10
-10
autonomous_driving/Online-HD-Map-Construction/tools/dist_train.sh
...us_driving/Online-HD-Map-Construction/tools/dist_train.sh
+9
-9
autonomous_driving/Online-HD-Map-Construction/tools/evaluate_submission.py
...g/Online-HD-Map-Construction/tools/evaluate_submission.py
+28
-28
autonomous_driving/Online-HD-Map-Construction/tools/mmdet_test.py
...us_driving/Online-HD-Map-Construction/tools/mmdet_test.py
+190
-190
autonomous_driving/Online-HD-Map-Construction/tools/mmdet_train.py
...s_driving/Online-HD-Map-Construction/tools/mmdet_train.py
+170
-170
autonomous_driving/Online-HD-Map-Construction/tools/test.py
autonomous_driving/Online-HD-Map-Construction/tools/test.py
+196
-196
autonomous_driving/Online-HD-Map-Construction/tools/train.py
autonomous_driving/Online-HD-Map-Construction/tools/train.py
+261
-261
autonomous_driving/Online-HD-Map-Construction/tools/visualization/renderer.py
...nline-HD-Map-Construction/tools/visualization/renderer.py
+222
-222
No files found.
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/heads/dg_head.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/heads/dg_head.py
View file @
f3b13cad
import
copy
import
copy
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmcv.cnn
import
Linear
,
bias_init_with_prob
,
build_activation_layer
from
mmcv.cnn
import
Linear
,
bias_init_with_prob
,
build_activation_layer
from
mmcv.cnn.bricks.transformer
import
build_positional_encoding
from
mmcv.cnn.bricks.transformer
import
build_positional_encoding
from
mmcv.runner
import
force_fp32
from
mmcv.runner
import
force_fp32
from
mmdet.models
import
HEADS
,
build_head
,
build_loss
from
mmdet.models
import
HEADS
,
build_head
,
build_loss
from
mmdet.models.utils
import
build_transformer
from
mmdet.models.utils
import
build_transformer
from
mmdet.models.utils.transformer
import
inverse_sigmoid
from
mmdet.models.utils.transformer
import
inverse_sigmoid
from
.base_map_head
import
BaseMapHead
from
.base_map_head
import
BaseMapHead
import
numpy
as
np
import
numpy
as
np
from
..augmentation.sythesis_det
import
NoiseSythesis
from
..augmentation.sythesis_det
import
NoiseSythesis
@HEADS.register_module(force=True)
class DGHead(BaseMapHead):
    """Detection + generation ("DG") map head.

    Composes two sub-heads built from configs: a map-element detector
    (``det_net``) and a polyline generator (``gen_net``). Training uses a
    teacher-forcing strategy (see ``forward_train``); inference runs the
    detector and then autoregressively generates polylines.
    """

    def __init__(self,
                 det_net_cfg=dict(),      # config for the detection sub-head
                 gen_net_cfg=dict(),      # config for the generation sub-head
                 loss_vert=dict(),        # NOTE(review): not used in this class — kept for config compatibility?
                 loss_face=dict(),        # NOTE(review): not used in this class
                 max_num_vertices=90,     # stored cap on generated vertices (not applied directly here)
                 top_p_gen_model=0.9,     # nucleus-sampling p forwarded to gen_net at inference
                 sync_cls_avg_factor=True,
                 augmentation=False,      # enable NoiseSythesis input augmentation for training
                 augmentation_kwargs=None,
                 joint_training=False,    # feed detached detector outputs into gen_net during training
                 **kwargs):
        # NOTE(review): mutable default arguments (dict()) — safe here because
        # they are only read, but None defaults would be cleaner.
        super().__init__()
        # Heads
        self.det_net = build_head(det_net_cfg)
        self.gen_net = build_head(gen_net_cfg)
        # Coordinate dimensionality is dictated by the generator.
        self.coord_dim = self.gen_net.coord_dim
        # Loss params
        self.bg_cls_weight = 1.0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        self.max_num_vertices = max_num_vertices
        self.top_p_gen_model = top_p_gen_model
        self.fp16_enabled = False
        self.augmentation = None
        if augmentation:
            # gen_net_cfg is accessed attribute-style here, so it is presumably
            # an mmcv Config node rather than a plain dict — TODO confirm.
            augmentation_kwargs.update({'canvas_size': gen_net_cfg.canvas_size})
            self.augmentation = NoiseSythesis(**augmentation_kwargs)
        self.joint_training = joint_training
def
forward
(
self
,
batch
,
img_metas
=
None
,
**
kwargs
):
def
forward
(
self
,
batch
,
img_metas
=
None
,
**
kwargs
):
'''
'''
Args:
Args:
Returns:
Returns:
outs (Dict):
outs (Dict):
'''
'''
if
self
.
training
:
if
self
.
training
:
return
self
.
forward_train
(
batch
,
**
kwargs
)
return
self
.
forward_train
(
batch
,
**
kwargs
)
else
:
else
:
return
self
.
inference
(
batch
,
**
kwargs
)
return
self
.
inference
(
batch
,
**
kwargs
)
    def forward_train(self, batch: dict, context: dict, only_det=False, **kwargs):
        ''' we use teacher force strategy'''
        # Run the detector on the BEV context first.
        bbox_dict = self.det_net(context=context)
        outs = dict(
            bbox=bbox_dict,
        )
        # Detection loss also yields the prediction/GT matching (per decoder
        # layer), which joint training reuses below.
        losses_dict, det_match_idxs, det_match_gt_idxs = \
            self.loss_det(batch, outs)
        if only_det:
            return outs, losses_dict
        if self.augmentation is not None:
            # Synthesize noisy polylines/boxes as generator input.
            polylines, bbox_flat = \
                self.augmentation(batch['gen'], simple_aug=True)
            if bbox_flat is None:
                # Augmentation may leave boxes untouched; fall back to GT boxes.
                bbox_flat = batch['gen']['bbox_flat']
            gen_input = dict(
                lines_bs_idx=batch['gen']['lines_bs_idx'],
                lines_cls=batch['gen']['lines_cls'],
                bbox_flat=bbox_flat,
                polylines=polylines,
                polyline_masks=batch['gen']['polyline_masks']
            )
        else:
            gen_input = batch['gen']
        if self.joint_training:
            # Condition the generator on detached detector outputs of the last
            # decoder layer instead of pure GT boxes.
            # for down stream polyline
            if 'lines' in bbox_dict[-1]:
                # for fix anchor
                pred_bbox = bbox_dict[-1]['lines'].detach()
            elif 'bboxs' in bbox_dict[-1]:
                # for rpv
                pred_bbox = bbox_dict[-1]['bboxs'].detach()
            else:
                raise NotImplementedError
            # changed to original gt order.
            det_match_idx = det_match_idxs[-1]
            det_match_gt_idx = det_match_gt_idxs[-1]
            _bboxs = []
            for i, (match_idx, bbox) in enumerate(zip(det_match_idx, pred_bbox)):
                _bboxs.append(bbox[match_idx])
                # Reorder the matched predictions back into GT order.
                _bboxs[-1] = _bboxs[-1][torch.argsort(det_match_gt_idx[i])]
            _bboxs = torch.cat(_bboxs, dim=0)
            # quantize the data
            _bboxs = \
                torch.round(_bboxs).type(torch.int32)
            # gen_input['bbox_flat'] = _bboxs
            # Randomly sample 20% of the entries to duplicate as extra data.
            remain_idx = torch.randperm(_bboxs.shape[0])[:int(_bboxs.shape[0] * 0.2)]
            # for data efficient
            for k in gen_input.keys():
                if k == 'bbox_flat':
                    # Predicted boxes replace GT boxes; a 20% GT subset is appended.
                    gen_input[k] = torch.cat((_bboxs, gen_input[k][remain_idx]), dim=0)
                else:
                    gen_input[k] = torch.cat((gen_input[k], gen_input[k][remain_idx]), dim=0)
        if isinstance(context['bev_embeddings'], tuple):
            # Multi-scale features: use the first level only.
            context['bev_embeddings'] = context['bev_embeddings'][0]
        poly_dict = self.gen_net(gen_input, context=context)
        outs.update(dict(
            polylines=poly_dict,
        ))
        if self.joint_training:
            # Mirror the duplication applied to gen_input so the GT targets
            # stay aligned with the enlarged prediction batch.
            for k in batch['gen'].keys():
                batch['gen'][k] = \
                    torch.cat((batch['gen'][k], batch['gen'][k][remain_idx]), dim=0)
        gen_losses_dict = \
            self.loss_gen(batch, outs)
        losses_dict.update(gen_losses_dict)
        return outs, losses_dict
def
loss_det
(
self
,
gt
:
dict
,
pred
:
dict
):
def
loss_det
(
self
,
gt
:
dict
,
pred
:
dict
):
loss_dict
=
{}
loss_dict
=
{}
# det
# det
det_loss_dict
,
det_match_idx
,
det_match_gt_idx
=
\
det_loss_dict
,
det_match_idx
,
det_match_gt_idx
=
\
self
.
det_net
.
loss
(
gt
[
'det'
],
pred
[
'bbox'
])
self
.
det_net
.
loss
(
gt
[
'det'
],
pred
[
'bbox'
])
for
k
,
v
in
det_loss_dict
.
items
():
for
k
,
v
in
det_loss_dict
.
items
():
loss_dict
[
'det_'
+
k
]
=
v
loss_dict
[
'det_'
+
k
]
=
v
return
loss_dict
,
det_match_idx
,
det_match_gt_idx
return
loss_dict
,
det_match_idx
,
det_match_gt_idx
def
loss_gen
(
self
,
gt
:
dict
,
pred
:
dict
):
def
loss_gen
(
self
,
gt
:
dict
,
pred
:
dict
):
loss_dict
=
{}
loss_dict
=
{}
# gen
# gen
gen_loss_dict
=
self
.
gen_net
.
loss
(
gt
[
'gen'
],
pred
[
'polylines'
])
gen_loss_dict
=
self
.
gen_net
.
loss
(
gt
[
'gen'
],
pred
[
'polylines'
])
for
k
,
v
in
gen_loss_dict
.
items
():
for
k
,
v
in
gen_loss_dict
.
items
():
loss_dict
[
'gen_'
+
k
]
=
v
loss_dict
[
'gen_'
+
k
]
=
v
return
loss_dict
return
loss_dict
def
loss
(
self
,
gt
:
dict
,
pred
:
dict
):
def
loss
(
self
,
gt
:
dict
,
pred
:
dict
):
pass
pass
    @torch.no_grad()
    def inference(self, batch: dict = {}, context: dict = {}, gt_condition=False, **kwargs):
        '''
        num_samples_batch: number of sample per batch (batch size)
        '''
        # NOTE(review): mutable default arguments ({}) — only read here, but
        # None defaults would be safer.
        outs = {}
        # Detect map elements, then decode boxes/scores/labels.
        bbox_dict = self.det_net(context=context)
        bbox_dict = self.det_net.post_process(bbox_dict)
        outs.update(bbox_dict)
        if len(outs['lines_bs_idx']) == 0:
            # Nothing detected anywhere in the batch.
            return None
        if isinstance(context['bev_embeddings'], tuple):
            # Multi-scale features: use the first level only.
            context['bev_embeddings'] = context['bev_embeddings'][0]
        # Autoregressively generate polylines conditioned on the detections.
        poly_dict = self.gen_net(outs,
                                 context=context,
                                 # max_sample_length=self.max_num_vertices,
                                 # NOTE(review): hard-coded cap bypasses self.max_num_vertices.
                                 max_sample_length=64,
                                 top_p=self.top_p_gen_model,
                                 gt_condition=gt_condition)
        outs.update(poly_dict)
        return outs
    def post_process(self, preds: dict, tokens, gts: dict = None, **kwargs):
        '''
        Convert detector/generator outputs into one result dict per sample.

        Args:
            preds (dict): reads 'scores', 'labels', 'lines_bs_idx',
                'polylines' and 'polyline_masks'.
            tokens: per-sample identifiers; len(tokens) is the batch size.
            gts (dict, optional): when given, ground truth is packed via
                pack_groundtruth (result currently unused — see the
                commented-out lines below).

        Outs:
            list[dict]: per sample: 'token', 'scores', 'labels', 'nline' and
                'vectors' — a list of (N, coord_dim) arrays normalized by the
                canvas extent.
        '''
        # Canvas extent used to normalize integer coordinates.
        range_size = self.gen_net.canvas_size.cpu().numpy()
        coord_dim = self.gen_net.coord_dim
        gen_net_name = self.gen_net.name if hasattr(self.gen_net, 'name') else 'gen'
        ret_list = []
        for batch_idx in range(len(tokens)):
            ret_dict_single = {}
            # bbox
            det_gt = None
            if gts is not None:
                det_gt, rec_groundtruth = pack_groundtruth(
                    batch_idx, gts, tokens, range_size, gen_net_name, coord_dim=coord_dim)
            bbox_res = {
                # 'bboxes': preds['bbox'][batch_idx].detach().cpu().numpy(),
                # 'det_gt': det_gt,
                'token': tokens[batch_idx],
                'scores': preds['scores'][batch_idx].detach().cpu().numpy(),
                'labels': preds['labels'][batch_idx].detach().cpu().numpy(),
            }
            ret_dict_single.update(bbox_res)
            # for gen results.
            # Indices of flattened polylines belonging to this sample.
            batch2seq = np.nonzero(
                preds['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
            ret_dict_single.update({
                'nline': len(batch2seq),
                'vectors': []
            })
            for i in batch2seq:
                pre = preds['polylines'][i].detach().cpu().numpy()
                pre_msk = preds['polyline_masks'][i].detach().cpu().numpy()
                # Drop the last valid position (stop token, presumably — TODO confirm).
                valid_idx = np.nonzero(pre_msk)[0][:-1]
                # From [200,1] to [199,0] to (1,0)
                line = (pre[valid_idx].reshape(-1, coord_dim) - 1) / (range_size - 1)
                ret_dict_single['vectors'].append(line)
            # if gts is not None:
            #     ret_dict_single['groundTruth'] = rec_groundtruth
            ret_list.append(ret_dict_single)
        return ret_list
def pack_groundtruth(batch_idx, gts, tokens, range_size, gen_net_name='gen', coord_dim=2):
    """Pack detection and generation ground truth for one batch sample.

    Args:
        batch_idx (int): which sample of the batch to extract.
        gts (dict): ground truth with 'det' and 'gen' sub-dicts of tensors.
        tokens: per-sample identifiers.
        range_size (np.ndarray): canvas extent used to normalize coordinates.
        gen_net_name (str): 'gen_gmm' keeps every valid point; any other
            value drops the final valid index (stop token, presumably).
        coord_dim (int): coordinate dimensionality of the polylines.

    Returns:
        tuple: (det_gt, ret_groundtruth) — detection labels/boxes and the
        per-sample polyline ground truth.
    """
    # Keypoint-style GT takes precedence over plain boxes when present.
    det_key = 'keypoints' if 'keypoints' in gts['det'] else 'bbox'
    det_gt = {
        'labels': gts['det']['class_label'][batch_idx].detach().cpu().numpy(),
        'bboxes': gts['det'][det_key][batch_idx].detach().cpu().numpy(),
    }
    # Flattened-polyline indices belonging to this sample.
    seq_ids = np.nonzero(
        gts['gen']['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
    ret_groundtruth = {
        'token': tokens[batch_idx],
        'nline': len(seq_ids),
        'labels': gts['gen']['lines_cls'][seq_ids].detach().cpu().numpy(),
        'lines': [],
    }
    all_lines = gts['gen']['polylines'].detach().cpu().numpy()
    all_masks = gts['gen']['polyline_masks'].detach().cpu().numpy()
    drop_last = gen_net_name != 'gen_gmm'
    for idx in seq_ids:
        keep = np.nonzero(all_masks[idx])[0]
        if drop_last:
            keep = keep[:-1]
        # From [200,1] to [199,0] to (1,0): shift to zero-based, then
        # normalize by the canvas extent.
        pts = (all_lines[idx][keep].reshape(-1, coord_dim) - 1) / (range_size - 1)
        ret_groundtruth['lines'].append(pts)
    return det_gt, ret_groundtruth
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/heads/map_element_detector.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/heads/map_element_detector.py
View file @
f3b13cad
import
copy
import
copy
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
mmcv.cnn
import
Conv2d
,
Linear
from
mmcv.cnn
import
Conv2d
,
Linear
from
mmcv.runner
import
force_fp32
from
mmcv.runner
import
force_fp32
from
torch.distributions.categorical
import
Categorical
from
torch.distributions.categorical
import
Categorical
from
mmdet.core
import
(
multi_apply
,
build_assigner
,
build_sampler
,
from
mmdet.core
import
(
multi_apply
,
build_assigner
,
build_sampler
,
reduce_mean
)
reduce_mean
)
from
mmdet.models
import
HEADS
from
mmdet.models
import
HEADS
from
.detr_bbox
import
DETRBboxHead
from
.detr_bbox
import
DETRBboxHead
from
mmdet.models.utils.transformer
import
inverse_sigmoid
from
mmdet.models.utils.transformer
import
inverse_sigmoid
from
mmdet.models
import
build_loss
from
mmdet.models
import
build_loss
from
mmcv.cnn
import
Linear
,
build_activation_layer
,
bias_init_with_prob
from
mmcv.cnn
import
Linear
,
build_activation_layer
,
bias_init_with_prob
from
mmcv.cnn.bricks.transformer
import
build_positional_encoding
from
mmcv.cnn.bricks.transformer
import
build_positional_encoding
from
mmdet.models.utils
import
build_transformer
from
mmdet.models.utils
import
build_transformer
@HEADS.register_module(force=True)
class MapElementDetector(nn.Module):
    """DETR-style head that detects map elements from BEV features.

    Builds an assigner/sampler pair for Hungarian matching, a transformer
    with positional encoding, and per-decoder-layer classification and
    regression branches.
    """

    def __init__(self,
                 canvas_size=(400, 200),     # (H, W) raster extent, stored as a buffer
                 discrete_output=False,      # regression head classifies canvas bins instead of regressing
                 separate_detect=False,
                 mode='xyxy',                # 'sce' implies 3 bbox keypoints, otherwise 2
                 bbox_size=None,             # explicit override of the mode-derived bbox_size
                 coord_dim=2,
                 kp_coord_dim=2,
                 num_classes=3,
                 in_channels=128,
                 num_query=100,
                 max_lines=50,
                 score_thre=0.2,
                 num_reg_fcs=2,
                 num_points=100,
                 iterative=False,            # independent branch weights per decoder layer
                 patch_size=None,            # NOTE(review): accepted but never used/stored
                 sync_cls_avg_factor=True,
                 transformer: dict = None,
                 positional_encoding: dict = None,
                 loss_cls: dict = None,
                 loss_reg: dict = None,
                 train_cfg: dict = None,):
        super().__init__()
        # NOTE(review): train_cfg['assigner'] is required — raises if
        # train_cfg is None (e.g. pure-inference construction).
        assigner = train_cfg['assigner']
        self.assigner = build_assigner(assigner)
        # DETR sampling=False, so use PseudoSampler
        sampler_cfg = dict(type='PseudoSampler')
        self.sampler = build_sampler(sampler_cfg, context=self)
        self.train_cfg = train_cfg
        self.max_lines = max_lines
        self.score_thre = score_thre
        self.num_query = num_query
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_points = num_points
        # branch
        # if loss_cls.use_sigmoid:
        if loss_cls['use_sigmoid']:
            # Sigmoid/focal classification: no explicit background class.
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1
        self.iterative = iterative
        self.num_reg_fcs = num_reg_fcs
        # Builds transformer + positional encoding and sets self.embed_dims.
        self._build_transformer(transformer, positional_encoding)
        # loss params
        self.loss_cls = build_loss(loss_cls)
        self.bg_cls_weight = 0.1
        if self.loss_cls.use_sigmoid:
            self.bg_cls_weight = 0.0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        self.reg_loss = build_loss(loss_reg)
        self.separate_detect = separate_detect
        self.discrete_output = discrete_output
        self.bbox_size = 3 if mode == 'sce' else 2
        if bbox_size is not None:
            self.bbox_size = bbox_size
        self.coord_dim = coord_dim  # for xyz
        self.kp_coord_dim = kp_coord_dim
        # Buffer so the canvas size moves with the module across devices.
        self.register_buffer('canvas_size', torch.tensor(canvas_size))
        # add reg, cls head for each decoder layer
        # NOTE(review): _init_embedding runs AFTER init_weights, so its
        # embeddings keep their default initialization.
        self._init_layers()
        self._init_branch()
        self.init_weights()
        self._init_embedding()
def
_init_layers
(
self
):
def
_init_layers
(
self
):
"""Initialize some layer."""
"""Initialize some layer."""
self
.
input_proj
=
Conv2d
(
self
.
input_proj
=
Conv2d
(
self
.
in_channels
,
self
.
embed_dims
,
kernel_size
=
1
)
self
.
in_channels
,
self
.
embed_dims
,
kernel_size
=
1
)
# query_pos_embed & query_embed
# query_pos_embed & query_embed
self
.
query_embedding
=
nn
.
Embedding
(
self
.
num_query
,
self
.
query_embedding
=
nn
.
Embedding
(
self
.
num_query
,
self
.
embed_dims
)
self
.
embed_dims
)
    def _init_embedding(self):
        """Create label / image-coordinate / query / bbox embedding tables.

        Called last in ``__init__`` (after ``init_weights``), so these tables
        keep their default initialization.
        """
        self.label_embed = nn.Embedding(self.num_classes,
                                        self.embed_dims)
        # Embeds a 2D (row, col) coordinate into the feature width; used by
        # _prepare_context on a normalized meshgrid.
        self.img_coord_embed = nn.Linear(2, self.embed_dims)
        # query_pos_embed & query_embed
        # NOTE(review): overwrites the (embed_dims)-wide query_embedding
        # created in _init_layers with a double-width one (content + position
        # halves, presumably — TODO confirm).
        self.query_embedding = nn.Embedding(self.num_query,
                                            self.embed_dims * 2)
        # for bbox parameter xstart, ystart, xend, yend
        self.bbox_embedding = nn.Embedding(self.bbox_size,
                                           self.embed_dims * 2)
    def _init_branch(self,):
        """Initialize classification branch and regression branch of head."""
        # Class head consumes the concatenation of all bbox keypoint features.
        fc_cls = Linear(self.embed_dims * self.bbox_size, self.cls_out_channels)
        # fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.LayerNorm(self.embed_dims))
            reg_branch.append(nn.ReLU())
        if self.discrete_output:
            # Classify over discrete canvas positions instead of regressing.
            reg_branch.append(nn.Linear(
                self.embed_dims, max(self.canvas_size), bias=True,))
        else:
            reg_branch.append(nn.Linear(
                self.embed_dims, self.coord_dim, bias=True,))
        reg_branch = nn.Sequential(*reg_branch)
        # add sigmoid or not

        def _get_clones(module, N):
            # Deep copies: independent weights per decoder layer.
            return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
        num_pred = self.transformer.decoder.num_layers
        if self.iterative:
            fc_cls = _get_clones(fc_cls, num_pred)
            reg_branch = _get_clones(reg_branch, num_pred)
        else:
            # NOTE: without deepcopy every decoder layer shares the SAME
            # module instance (tied weights).
            reg_branch = nn.ModuleList(
                [reg_branch for _ in range(num_pred)])
            fc_cls = nn.ModuleList(
                [fc_cls for _ in range(num_pred)])
        self.pre_branches = nn.ModuleDict([
            ('cls', fc_cls),
            ('reg', reg_branch), ])
def
init_weights
(
self
):
def
init_weights
(
self
):
"""Initialize weights of the DeformDETR head."""
"""Initialize weights of the DeformDETR head."""
for
p
in
self
.
input_proj
.
parameters
():
for
p
in
self
.
input_proj
.
parameters
():
if
p
.
dim
()
>
1
:
if
p
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
p
)
nn
.
init
.
xavier_uniform_
(
p
)
self
.
transformer
.
init_weights
()
self
.
transformer
.
init_weights
()
# init prediction branch
# init prediction branch
for
k
,
v
in
self
.
pre_branches
.
items
():
for
k
,
v
in
self
.
pre_branches
.
items
():
for
param
in
v
.
parameters
():
for
param
in
v
.
parameters
():
if
param
.
dim
()
>
1
:
if
param
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
param
)
nn
.
init
.
xavier_uniform_
(
param
)
# focal loss init
# focal loss init
if
self
.
loss_cls
.
use_sigmoid
:
if
self
.
loss_cls
.
use_sigmoid
:
bias_init
=
bias_init_with_prob
(
0.01
)
bias_init
=
bias_init_with_prob
(
0.01
)
# for last layer
# for last layer
if
isinstance
(
self
.
pre_branches
[
'cls'
],
nn
.
ModuleList
):
if
isinstance
(
self
.
pre_branches
[
'cls'
],
nn
.
ModuleList
):
for
m
in
self
.
pre_branches
[
'cls'
]:
for
m
in
self
.
pre_branches
[
'cls'
]:
nn
.
init
.
constant_
(
m
.
bias
,
bias_init
)
nn
.
init
.
constant_
(
m
.
bias
,
bias_init
)
else
:
else
:
m
=
self
.
pre_branches
[
'cls'
]
m
=
self
.
pre_branches
[
'cls'
]
nn
.
init
.
constant_
(
m
.
bias
,
bias_init
)
nn
.
init
.
constant_
(
m
.
bias
,
bias_init
)
def
_build_transformer
(
self
,
transformer
,
positional_encoding
):
def
_build_transformer
(
self
,
transformer
,
positional_encoding
):
# transformer
# transformer
self
.
act_cfg
=
transformer
.
get
(
'act_cfg'
,
self
.
act_cfg
=
transformer
.
get
(
'act_cfg'
,
dict
(
type
=
'ReLU'
,
inplace
=
True
))
dict
(
type
=
'ReLU'
,
inplace
=
True
))
self
.
activate
=
build_activation_layer
(
self
.
act_cfg
)
self
.
activate
=
build_activation_layer
(
self
.
act_cfg
)
self
.
positional_encoding
=
build_positional_encoding
(
self
.
positional_encoding
=
build_positional_encoding
(
positional_encoding
)
positional_encoding
)
self
.
transformer
=
build_transformer
(
transformer
)
self
.
transformer
=
build_transformer
(
transformer
)
self
.
embed_dims
=
self
.
transformer
.
embed_dims
self
.
embed_dims
=
self
.
transformer
.
embed_dims
    def _prepare_context(self, context):
        """Prepare class label and vertex context."""
        global_context_embedding = None
        image_embeddings = context['bev_embeddings']
        image_embeddings = self.input_proj(
            image_embeddings)  # only change feature size
        # Pass images through encoder
        device = image_embeddings.device
        # Add 2D coordinate grid embedding
        B, C, H, W = image_embeddings.shape
        # Normalized coordinates in [-1, 1] along each spatial axis.
        Ws = torch.linspace(-1., 1., W)
        Hs = torch.linspace(-1., 1., H)
        # (H, W, 2) coordinate grid.
        image_coords = torch.stack(
            torch.meshgrid(Hs, Ws), dim=-1).to(device)
        # (H, W, embed_dims) positional features.
        image_coord_embeddings = self.img_coord_embed(image_coords)
        # (H, W, C) -> (1, C, H, W); broadcast-added over the batch in place.
        image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)
        # Reshape spatial grid to sequence
        # NOTE(review): reshaping to (B, C, H, W) is a no-op here — possibly a
        # leftover from a flattening variant. TODO confirm.
        sequential_context_embeddings = image_embeddings.reshape(
            B, C, H, W)
        return (global_context_embedding, sequential_context_embeddings)
def forward(self, context, img_metas=None, multi_scale=False):
    """Run the detection transformer over BEV features.

    Args:
        context (dict): Holds 'bev_embeddings' of shape [B, C, H, W].
        img_metas (list, optional): Per-sample meta info (unused here).
        multi_scale (bool): Unused flag kept for interface compatibility.

    Returns:
        list[dict]: One prediction dict per decoder layer (see
            ``get_prediction``) with classification scores and line
            coordinates for every query.
    """
    _, x = self._prepare_context(context)
    B, C, H, W = x.shape

    # Tile each query embedding across the bbox keypoints, then add the
    # shared per-keypoint embedding before flattening for the decoder.
    queries = self.query_embedding.weight[None, :, None].repeat(
        B, 1, self.bbox_size, 1)
    queries = queries + self.bbox_embedding.weight[None, None]
    queries = queries.view(B, -1, C * 2)

    # All-zero mask: every spatial location is valid.
    img_masks = x.new_zeros((B, H, W))
    pos_embed = self.positional_encoding(img_masks)

    # hs: [nb_dec, bs, num_query, embed_dim]
    hs, init_reference, inter_references = self.transformer(
        [x, ],
        [img_masks.type(torch.bool)],
        queries,
        [pos_embed],
        reg_branches=self.reg_branches if self.iterative else None,  # noqa:E501
        cls_branches=None,  # noqa:E501
    )
    outs_dec = hs.permute(0, 2, 1, 3)

    outputs = []
    for level, query_feat in enumerate(outs_dec):
        # Level 0 refines the initial reference; later levels refine the
        # reference produced by the previous decoder layer.
        ref = init_reference if level == 0 else inter_references[level - 1]
        outputs.append(self.get_prediction(level, query_feat, ref))
    return outputs
def get_prediction(self, level, query_feat, reference):
    """Decode one decoder level's queries into class scores and lines.

    Args:
        level (int): Decoder layer index selecting the prediction branch.
        query_feat (Tensor): Decoder output features, [bs, num_query, h].
        reference (Tensor): Reference points in (0, 1) for this level
            (passed through ``inverse_sigmoid`` before refinement).

    Returns:
        dict: 'lines' [bs, num_query, bbox_size * coord_dim] scaled to
            canvas coordinates, 'scores' [bs, num_query, num_class], and
            'embeddings' [bs, num_query, bbox_size, h].
    """
    bs, num_query, h = query_feat.shape
    # Regroup queries so each bbox's keypoints sit on their own axis.
    query_feat = query_feat.view(bs, -1, self.bbox_size, h)
    # Classify from the concatenated keypoint features of each bbox.
    ocls = self.pre_branches['cls'][level](query_feat.flatten(-2))
    # ocls = ocls.mean(-2)
    # Move references to logit space so regression offsets add linearly.
    reference = inverse_sigmoid(reference)
    reference = reference.view(bs, -1, self.bbox_size, self.coord_dim)
    tmp = self.pre_branches['reg'][level](query_feat)
    # Only the first kp_coord_dim dims are refined relative to the reference;
    # any remaining regression dims are predicted directly.
    tmp[..., :self.kp_coord_dim] = tmp[..., :self.kp_coord_dim] + reference[..., :self.kp_coord_dim]
    lines = tmp.sigmoid()  # bs, num_query, self.bbox_size,2
    # Scale normalized (0, 1) coordinates up to the BEV canvas extent.
    lines = lines * self.canvas_size[:self.coord_dim]
    lines = lines.flatten(-2)
    return dict(
        lines=lines,  # [bs, num_query, bboxsize*2]
        scores=ocls,  # [bs, num_query, num_class]
        embeddings=query_feat,  # [bs, num_query, bbox_size, h]
    )
@force_fp32(apply_to=('score_pred', 'lines_pred', 'gt_lines'))
def _get_target_single(self,
                       score_pred,
                       lines_pred,
                       gt_labels,
                       gt_lines,
                       gt_bboxes_ignore=None):
    """
    Compute regression and classification targets for one image.
    Outputs from a single decoder layer of a single feature level are used.
    Args:
        score_pred (Tensor): Box score logits from a single decoder layer
            for one image. Shape [num_query, cls_out_channels].
        lines_pred (Tensor):
            shape [num_query, num_points, 2].
        gt_lines (Tensor):
            shape [num_gt, num_points, 2].
        gt_labels (torch.LongTensor)
            shape [num_gt, ]
        gt_bboxes_ignore: unused; forwarded to the assigner.
    Returns:
        tuple[Tensor]: a tuple containing the following for one image.
            - labels (LongTensor): Labels of each image.
                shape [num_query, 1]
            - label_weights (Tensor]): Label weights of each image.
                shape [num_query, 1]
            - lines_target (Tensor): Lines targets of each image.
                shape [num_query, num_points, 2]
            - lines_weights (Tensor): Lines weights of each image.
                shape [num_query, num_points, 2]
            - pos_inds (Tensor): Sampled positive indices for each image.
            - neg_inds (Tensor): Sampled negative indices for each image.
            - pos_gt_inds (Tensor): Assigned GT index for each positive.
    """
    num_pred_lines = len(lines_pred)
    # assigner and sampler: Hungarian-style matching of predictions to GT.
    assign_result = self.assigner.assign(preds=dict(lines=lines_pred,
                                                    scores=score_pred,),
                                         gts=dict(lines=gt_lines,
                                                  labels=gt_labels, ),
                                         gt_bboxes_ignore=gt_bboxes_ignore)
    sampling_result = self.sampler.sample(
        assign_result, lines_pred, gt_lines)
    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds
    pos_gt_inds = sampling_result.pos_assigned_gt_inds

    # label targets 0: foreground, 1: background
    if self.separate_detect:
        # Class-agnostic detection: everything starts as background (1).
        labels = gt_lines.new_full((num_pred_lines,), 1, dtype=torch.long)
    else:
        # Multi-class: background is encoded as num_classes.
        labels = gt_lines.new_full(
            (num_pred_lines,), self.num_classes, dtype=torch.long)
    labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
    label_weights = gt_lines.new_ones(num_pred_lines)

    # bbox targets: when output is discrete, lines_pred's last dimension is
    # the vocabulary and ground truth does not have this dimension, so the
    # target is an integer class per coordinate instead of a float.
    if self.discrete_output:
        lines_target = torch.zeros_like(lines_pred[..., 0]).long()
        lines_weights = torch.zeros_like(lines_pred[..., 0])
    else:
        lines_target = torch.zeros_like(lines_pred)
        lines_weights = torch.zeros_like(lines_pred)
    lines_target[pos_inds] = sampling_result.pos_gt_bboxes.type(
        lines_target.dtype)
    lines_weights[pos_inds] = 1.0

    # Normalize per-line weights so every matched line contributes equally
    # regardless of its number of coordinates (guard against division by 0).
    n = lines_weights.sum(-1, keepdim=True)
    lines_weights = lines_weights / n.masked_fill(n == 0, 1)

    return (labels, label_weights, lines_target, lines_weights,
            pos_inds, neg_inds, pos_gt_inds)
# @force_fp32(apply_to=('preds', 'gts'))
def get_targets(self, preds, gts, gt_bboxes_ignore_list=None):
    """
    Compute regression and classification targets for a batch image.
    Outputs from a single decoder layer of a single feature level are used.
    Args:
        preds (dict): Batched predictions with 'scores'
            [bs, num_query, cls_out_channels] and 'lines'
            (tensor, or a distribution when ``self.discrete_output``).
        gts (dict): Ground truth; either 'class_label'/'bbox' lists, or
            'bbox'/'bbox_mask' when ``self.separate_detect``.
        gt_bboxes_ignore_list (list[Tensor], optional): Must be None.
    Returns:
        tuple: (new_gts dict of stacked per-sample targets,
            num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list).
    """
    assert gt_bboxes_ignore_list is None, \
        'Only supports for gt_bboxes_ignore setting to None.'

    # Normalise the ground truth into per-sample label/box lists.
    if self.separate_detect:
        masks = gts['bbox_mask']
        bbox = [b[m] for b, m in zip(gts['bbox'], masks)]
        # Detection-only mode: every kept box carries class 0.
        zero_cls = torch.zeros_like(masks).long()
        class_label = [c[m] for c, m in zip(zero_cls, masks)]
    else:
        class_label = gts['class_label']
        bbox = gts['bbox']

    # Discrete heads expose a distribution; match against its logits.
    lines_pred = preds['lines'].logits if self.discrete_output \
        else preds['lines']

    bbox = [b.float() for b in bbox]

    (labels_list, label_weights_list,
     lines_targets_list, lines_weights_list,
     pos_inds_list, neg_inds_list, pos_gt_inds_list) = multi_apply(
        self._get_target_single,
        preds['scores'], lines_pred,
        class_label, bbox,
        gt_bboxes_ignore=gt_bboxes_ignore_list)

    num_total_pos = sum(inds.numel() for inds in pos_inds_list)
    num_total_neg = sum(inds.numel() for inds in neg_inds_list)

    new_gts = dict(
        labels=labels_list,
        label_weights=label_weights_list,
        bboxs=lines_targets_list,
        bboxs_weights=lines_weights_list,
    )
    return (new_gts, num_total_pos, num_total_neg,
            pos_inds_list, pos_gt_inds_list)
# @force_fp32(apply_to=('preds', 'gts'))
def loss_single(self,
                preds: dict,
                gts: dict,
                gt_bboxes_ignore_list=None,
                reduction='none'):
    """
    Loss function for outputs from a single decoder layer of a single
    feature level.
    Args:
        preds (dict): 'scores' [bs, num_query, cls_out_channels] and
            'lines' (tensor, or a distribution when
            ``self.discrete_output``).
        gts (dict): Ground truth consumed by ``get_targets``.
        gt_bboxes_ignore_list (list[Tensor], optional): Bounding
            boxes which can be ignored for each image. Default None.
        reduction (str): Accepted for interface compatibility.
            NOTE(review): not used in this body — confirm callers.
    Returns:
        tuple: (dict with 'cls' and 'reg' loss tensors,
            pos_inds_list, pos_gt_inds_list).
    """
    # Get target for each sample
    new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list = \
        self.get_targets(preds, gts, gt_bboxes_ignore_list)
    # Batch all data: stack per-sample target lists into tensors.
    for k, v in new_gts.items():
        new_gts[k] = torch.stack(v, dim=0)

    # construct weighted avg_factor to match with the official DETR repo
    cls_avg_factor = num_total_pos * 1.0 + \
        num_total_neg * self.bg_cls_weight
    if self.sync_cls_avg_factor:
        # Average the factor across GPUs for consistent normalization.
        cls_avg_factor = reduce_mean(
            preds['scores'].new_tensor([cls_avg_factor]))
    cls_avg_factor = max(cls_avg_factor, 1)

    # Classification loss
    if self.separate_detect:
        # Binary foreground/background objective.
        loss_cls = self.bce_loss(
            preds['scores'], new_gts['labels'],
            new_gts['label_weights'], cls_avg_factor)
    else:
        # since the inputs needs the second dim is the class dim,
        # we flatten the prediction and targets.
        cls_scores = preds['scores'].reshape(-1, self.cls_out_channels)
        cls_labels = new_gts['labels'].reshape(-1)
        cls_weights = new_gts['label_weights'].reshape(-1)
        loss_cls = self.loss_cls(
            cls_scores, cls_labels, cls_weights, avg_factor=cls_avg_factor)

    # Compute the average number of gt boxes across all gpus, for
    # normalization purposes
    num_total_pos = loss_cls.new_tensor([num_total_pos])
    num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

    # position NLL loss
    if self.discrete_output:
        # Discrete head: negative log-likelihood of the target tokens,
        # masked and normalized by the (synced) positive count.
        loss_reg = -(preds['lines'].log_prob(new_gts['bboxs']) *
                     new_gts['bboxs_weights']).sum() / (num_total_pos)
    else:
        loss_reg = self.reg_loss(
            preds['lines'], new_gts['bboxs'], new_gts['bboxs_weights'],
            avg_factor=num_total_pos)

    loss_dict = dict(
        cls=loss_cls,
        reg=loss_reg,
    )
    return loss_dict, pos_inds_list, pos_gt_inds_list
@force_fp32(apply_to=('gt_lines_list', 'preds_dicts'))
def loss(self,
         gts: dict,
         preds_dicts: dict,
         gt_bboxes_ignore=None,
         reduction='mean'):
    """
    Loss Function.
    Args:
        gts (dict): Ground truth consumed by ``loss_single``.
        preds_dicts: per-decoder-layer prediction dicts, each holding
            classification scores [nb_dec, bs, num_query,
            cls_out_channels] and line predictions.
        gt_bboxes_ignore (list[Tensor], optional): Must be None.
        reduction (str): Forwarded to ``loss_single``.
    Returns:
        tuple: (dict of named loss components, pos_inds_lists,
            pos_gt_inds_lists).
    """
    assert gt_bboxes_ignore is None, \
        f'{self.__class__.__name__} only supports ' \
        f'for gt_bboxes_ignore setting to None.'

    # One loss_single call per decoder layer.
    losses, pos_inds_lists, pos_gt_inds_lists = multi_apply(
        self.loss_single,
        preds_dicts,
        gts=gts,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        reduction=reduction)

    # The last decoder layer keeps the bare loss keys...
    loss_dict = dict(losses[-1])
    # ...earlier (auxiliary) layers get a per-layer prefix.
    for num_dec_layer, layer_losses in enumerate(losses[:-1]):
        for k, v in layer_losses.items():
            loss_dict[f'd{num_dec_layer}.{k}'] = v
    return loss_dict, pos_inds_lists, pos_gt_inds_lists
def post_process(self, preds_dicts: list, **kwargs):
    '''
    Select the top-scoring lines per sample from the last decoder layer.
    Args:
        preds_dicts: per-layer prediction dicts; only the last layer is
            decoded. It holds
            scores (Tensor): [bs, num_query, cls_out_channels],
            lines (Tensor): [bs, num_query, bbox parameters].
    Outs:
        result_dict (Dict): per-sample lists of 'bbox', 'scores' and
            'labels', plus batch-flattened 'bbox_flat' (rounded int32),
            'lines_cls' and 'lines_bs_idx' tensors consumed by the
            downstream polyline stage.
    '''
    final_preds = preds_dicts[-1]
    all_scores = final_preds['scores']
    all_lines = final_preds['lines']

    num_samples = all_scores.size(0)
    device = all_scores.device

    result_dict = {
        'bbox': [],
        'scores': [],
        'labels': [],
        'bbox_flat': [],
        'lines_cls': [],
        'lines_bs_idx': [],
    }
    keep = self.max_lines
    for sample_idx in range(num_samples):
        cls_scores = all_scores[sample_idx]
        det_preds = all_lines[sample_idx]
        if self.loss_cls.use_sigmoid:
            # Sigmoid head: rank all (query, class) pairs jointly, then
            # recover the class and query index from the flat position.
            flat_scores = cls_scores.sigmoid().view(-1)
            scores, valid_idx = flat_scores.topk(keep)
            det_labels = valid_idx % self.num_classes
            valid_idx = valid_idx // self.num_classes
            det_preds = det_preds[valid_idx]
        else:
            # Softmax head: drop the background column before the max.
            scores, det_labels = F.softmax(
                cls_scores, dim=-1)[..., :-1].max(-1)
            scores, valid_idx = scores.topk(keep)
            det_preds = det_preds[valid_idx]
            det_labels = det_labels[valid_idx]
        result_dict['bbox'].append(det_preds)
        result_dict['scores'].append(scores)
        result_dict['labels'].append(det_labels)
        result_dict['lines_bs_idx'].extend([sample_idx] * len(valid_idx))

    # for down stream polyline: flatten across the batch and quantize.
    stacked = torch.cat(result_dict['bbox'], dim=0)
    result_dict['bbox_flat'] = torch.round(stacked).type(torch.int32)
    result_dict['lines_cls'] = torch.cat(
        result_dict['labels'], dim=0).long()
    result_dict['lines_bs_idx'] = torch.tensor(
        result_dict['lines_bs_idx'], device=device).long()
    return result_dict
\ No newline at end of file
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/heads/polyline_generator.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/heads/polyline_generator.py
View file @
f3b13cad
import
math
import
math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.distributions.categorical
import
Categorical
from
torch.distributions.categorical
import
Categorical
from
mmdet.models
import
HEADS
from
mmdet.models
import
HEADS
from
.detgen_utils.causal_trans
import
(
CausalTransformerDecoder
,
from
.detgen_utils.causal_trans
import
(
CausalTransformerDecoder
,
CausalTransformerDecoderLayer
)
CausalTransformerDecoderLayer
)
from
.detgen_utils.utils
import
(
dequantize_verts
,
generate_square_subsequent_mask
,
from
.detgen_utils.utils
import
(
dequantize_verts
,
generate_square_subsequent_mask
,
quantize_verts
,
top_k_logits
,
top_p_logits
)
quantize_verts
,
top_k_logits
,
top_p_logits
)
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmcv.runner
import
force_fp32
,
auto_fp16
@
HEADS
.
register_module
(
force
=
True
)
@
HEADS
.
register_module
(
force
=
True
)
class
PolylineGenerator
(
nn
.
Module
):
class
PolylineGenerator
(
nn
.
Module
):
"""
"""
Autoregressive generative model of n-gon meshes.
Autoregressive generative model of n-gon meshes.
Operates on sets of input vertices as well as flattened face sequences with
Operates on sets of input vertices as well as flattened face sequences with
new face and stopping tokens:
new face and stopping tokens:
[f_0^0, f_0^1, f_0^2, NEW, f_1^0, f_1^1, ..., STOP]
[f_0^0, f_0^1, f_0^2, NEW, f_1^0, f_1^1, ..., STOP]
Input vertices are encoded using a Transformer encoder.
Input vertices are encoded using a Transformer encoder.
Input face sequences are embedded and tagged with learned position indicators,
Input face sequences are embedded and tagged with learned position indicators,
as well as their corresponding vertex embeddings. A transformer decoder
as well as their corresponding vertex embeddings. A transformer decoder
outputs a pointer which is compared to each vertex embedding to obtain a
outputs a pointer which is compared to each vertex embedding to obtain a
distribution over vertex indices.
distribution over vertex indices.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
in_channels
,
in_channels
,
encoder_config
,
encoder_config
,
decoder_config
,
decoder_config
,
class_conditional
=
True
,
class_conditional
=
True
,
num_classes
=
55
,
num_classes
=
55
,
decoder_cross_attention
=
True
,
decoder_cross_attention
=
True
,
use_discrete_vertex_embeddings
=
True
,
use_discrete_vertex_embeddings
=
True
,
condition_points_num
=
3
,
condition_points_num
=
3
,
coord_dim
=
2
,
coord_dim
=
2
,
canvas_size
=
(
400
,
200
),
canvas_size
=
(
400
,
200
),
max_seq_length
=
500
,
max_seq_length
=
500
,
name
=
'gen_model'
):
name
=
'gen_model'
):
"""Initializes FaceModel.
"""Initializes FaceModel.
Args:
Args:
encoder_config: Dictionary with TransformerEncoder config.
encoder_config: Dictionary with TransformerEncoder config.
decoder_config: Dictionary with TransformerDecoder config.
decoder_config: Dictionary with TransformerDecoder config.
class_conditional: If True, then condition on learned class embeddings.
class_conditional: If True, then condition on learned class embeddings.
num_classes: Number of classes to condition on.
num_classes: Number of classes to condition on.
decoder_cross_attention: If True, the use cross attention from decoder
decoder_cross_attention: If True, the use cross attention from decoder
querys into encoder outputs.
querys into encoder outputs.
use_discrete_vertex_embeddings: If True, use discrete vertex embeddings.
use_discrete_vertex_embeddings: If True, use discrete vertex embeddings.
max_seq_length: Maximum face sequence length. Used for learned position
max_seq_length: Maximum face sequence length. Used for learned position
embeddings.
embeddings.
name: Name of variable scope
name: Name of variable scope
"""
"""
super
(
PolylineGenerator
,
self
).
__init__
()
super
(
PolylineGenerator
,
self
).
__init__
()
self
.
embedding_dim
=
decoder_config
[
'layer_config'
][
'd_model'
]
self
.
embedding_dim
=
decoder_config
[
'layer_config'
][
'd_model'
]
self
.
class_conditional
=
class_conditional
self
.
class_conditional
=
class_conditional
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
max_seq_length
=
max_seq_length
self
.
max_seq_length
=
max_seq_length
self
.
decoder_cross_attention
=
decoder_cross_attention
self
.
decoder_cross_attention
=
decoder_cross_attention
self
.
use_discrete_vertex_embeddings
=
use_discrete_vertex_embeddings
self
.
use_discrete_vertex_embeddings
=
use_discrete_vertex_embeddings
self
.
condition_points_num
=
condition_points_num
self
.
condition_points_num
=
condition_points_num
self
.
fp16_enabled
=
False
self
.
fp16_enabled
=
False
self
.
coord_dim
=
coord_dim
# if we use xyz else 2 when we use xy
self
.
coord_dim
=
coord_dim
# if we use xyz else 2 when we use xy
self
.
kp_coord_dim
=
coord_dim
if
coord_dim
==
2
else
2
# XXX
self
.
kp_coord_dim
=
coord_dim
if
coord_dim
==
2
else
2
# XXX
self
.
register_buffer
(
'canvas_size'
,
torch
.
tensor
(
canvas_size
))
self
.
register_buffer
(
'canvas_size'
,
torch
.
tensor
(
canvas_size
))
# initialize the model
# initialize the model
self
.
_project_to_logits
=
nn
.
Linear
(
self
.
_project_to_logits
=
nn
.
Linear
(
self
.
embedding_dim
,
self
.
embedding_dim
,
max
(
canvas_size
)
+
1
,
# + 1 for stopping token. use_bias=True,
max
(
canvas_size
)
+
1
,
# + 1 for stopping token. use_bias=True,
)
)
self
.
input_proj
=
nn
.
Conv2d
(
self
.
input_proj
=
nn
.
Conv2d
(
in_channels
,
self
.
embedding_dim
,
kernel_size
=
1
)
in_channels
,
self
.
embedding_dim
,
kernel_size
=
1
)
decoder_layer
=
CausalTransformerDecoderLayer
(
decoder_layer
=
CausalTransformerDecoderLayer
(
**
decoder_config
.
pop
(
'layer_config'
))
**
decoder_config
.
pop
(
'layer_config'
))
self
.
decoder
=
CausalTransformerDecoder
(
self
.
decoder
=
CausalTransformerDecoder
(
decoder_layer
,
**
decoder_config
)
decoder_layer
,
**
decoder_config
)
self
.
_init_embedding
()
self
.
_init_embedding
()
self
.
init_weights
()
self
.
init_weights
()
def
_init_embedding
(
self
):
def
_init_embedding
(
self
):
if
self
.
class_conditional
:
if
self
.
class_conditional
:
self
.
label_embed
=
nn
.
Embedding
(
self
.
label_embed
=
nn
.
Embedding
(
self
.
num_classes
,
self
.
embedding_dim
)
self
.
num_classes
,
self
.
embedding_dim
)
self
.
coord_embed
=
nn
.
Embedding
(
self
.
coord_dim
,
self
.
embedding_dim
)
self
.
coord_embed
=
nn
.
Embedding
(
self
.
coord_dim
,
self
.
embedding_dim
)
self
.
pos_embeddings
=
nn
.
Embedding
(
self
.
pos_embeddings
=
nn
.
Embedding
(
self
.
max_seq_length
,
self
.
embedding_dim
)
self
.
max_seq_length
,
self
.
embedding_dim
)
# to indicate the role of the position is the start of the line or the end of it.
# to indicate the role of the position is the start of the line or the end of it.
self
.
bbox_context_embed
=
\
self
.
bbox_context_embed
=
\
nn
.
Embedding
(
self
.
condition_points_num
,
self
.
embedding_dim
)
nn
.
Embedding
(
self
.
condition_points_num
,
self
.
embedding_dim
)
self
.
img_coord_embed
=
nn
.
Linear
(
2
,
self
.
embedding_dim
)
self
.
img_coord_embed
=
nn
.
Linear
(
2
,
self
.
embedding_dim
)
# initialize the verteices embedding
# initialize the verteices embedding
if
self
.
use_discrete_vertex_embeddings
:
if
self
.
use_discrete_vertex_embeddings
:
self
.
vertex_embed
=
nn
.
Embedding
(
self
.
vertex_embed
=
nn
.
Embedding
(
max
(
self
.
canvas_size
)
+
1
,
self
.
embedding_dim
)
max
(
self
.
canvas_size
)
+
1
,
self
.
embedding_dim
)
else
:
else
:
self
.
vertex_embed
=
nn
.
Linear
(
1
,
self
.
embedding_dim
)
self
.
vertex_embed
=
nn
.
Linear
(
1
,
self
.
embedding_dim
)
def
init_weights
(
self
):
def
init_weights
(
self
):
for
p
in
self
.
parameters
():
for
p
in
self
.
parameters
():
if
p
.
dim
()
>
1
:
if
p
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
p
)
nn
.
init
.
xavier_uniform_
(
p
)
def
_embed_kps
(
self
,
bbox
):
def
_embed_kps
(
self
,
bbox
):
bbox_len
=
bbox
.
shape
[
-
1
]
bbox_len
=
bbox
.
shape
[
-
1
]
# Bbox_context
# Bbox_context
bbox_embedding
=
self
.
bbox_context_embed
(
bbox_embedding
=
self
.
bbox_context_embed
(
(
torch
.
arange
(
bbox_len
,
device
=
bbox
.
device
)
/
self
.
kp_coord_dim
).
floor
().
long
())
(
torch
.
arange
(
bbox_len
,
device
=
bbox
.
device
)
/
self
.
kp_coord_dim
).
floor
().
long
())
# Coord indicators (x, y)
# Coord indicators (x, y)
coord_embeddings
=
self
.
coord_embed
(
coord_embeddings
=
self
.
coord_embed
(
torch
.
arange
(
bbox_len
,
device
=
bbox
.
device
)
%
self
.
kp_coord_dim
)
torch
.
arange
(
bbox_len
,
device
=
bbox
.
device
)
%
self
.
kp_coord_dim
)
# Discrete vertex value embeddings
# Discrete vertex value embeddings
vert_embeddings
=
self
.
vertex_embed
(
bbox
)
vert_embeddings
=
self
.
vertex_embed
(
bbox
)
return
vert_embeddings
+
(
bbox_embedding
+
coord_embeddings
)[
None
]
return
vert_embeddings
+
(
bbox_embedding
+
coord_embeddings
)[
None
]
def
_prepare_context
(
self
,
batch
,
context
):
def
_prepare_context
(
self
,
batch
,
context
):
"""Prepare class label and vertex context."""
"""Prepare class label and vertex context."""
global_context_embedding
=
None
global_context_embedding
=
None
if
self
.
class_conditional
:
if
self
.
class_conditional
:
global_context_embedding
=
self
.
label_embed
(
batch
[
'lines_cls'
])
global_context_embedding
=
self
.
label_embed
(
batch
[
'lines_cls'
])
bbox_embeddings
=
self
.
_embed_kps
(
batch
[
'bbox_flat'
])
bbox_embeddings
=
self
.
_embed_kps
(
batch
[
'bbox_flat'
])
if
global_context_embedding
is
not
None
:
if
global_context_embedding
is
not
None
:
global_context_embedding
=
torch
.
cat
(
global_context_embedding
=
torch
.
cat
(
[
global_context_embedding
[:,
None
],
bbox_embeddings
],
dim
=
1
)
[
global_context_embedding
[:,
None
],
bbox_embeddings
],
dim
=
1
)
# Pass images through encoder
# Pass images through encoder
image_embeddings
=
assign_bev
(
image_embeddings
=
assign_bev
(
context
[
'bev_embeddings'
],
batch
[
'lines_bs_idx'
])
context
[
'bev_embeddings'
],
batch
[
'lines_bs_idx'
])
image_embeddings
=
self
.
input_proj
(
image_embeddings
)
image_embeddings
=
self
.
input_proj
(
image_embeddings
)
device
=
image_embeddings
.
device
device
=
image_embeddings
.
device
# Add 2D coordinate grid embedding
# Add 2D coordinate grid embedding
H
,
W
=
image_embeddings
.
shape
[
2
:]
H
,
W
=
image_embeddings
.
shape
[
2
:]
Ws
=
torch
.
linspace
(
-
1.
,
1.
,
W
)
Ws
=
torch
.
linspace
(
-
1.
,
1.
,
W
)
Hs
=
torch
.
linspace
(
-
1.
,
1.
,
H
)
Hs
=
torch
.
linspace
(
-
1.
,
1.
,
H
)
image_coords
=
torch
.
stack
(
image_coords
=
torch
.
stack
(
torch
.
meshgrid
(
Hs
,
Ws
),
dim
=-
1
).
to
(
device
)
torch
.
meshgrid
(
Hs
,
Ws
),
dim
=-
1
).
to
(
device
)
image_coord_embeddings
=
self
.
img_coord_embed
(
image_coords
)
image_coord_embeddings
=
self
.
img_coord_embed
(
image_coords
)
image_embeddings
+=
image_coord_embeddings
[
None
].
permute
(
0
,
3
,
1
,
2
)
image_embeddings
+=
image_coord_embeddings
[
None
].
permute
(
0
,
3
,
1
,
2
)
# Reshape spatial grid to sequence
# Reshape spatial grid to sequence
B
=
image_embeddings
.
shape
[
0
]
B
=
image_embeddings
.
shape
[
0
]
sequential_context_embeddings
=
image_embeddings
.
reshape
(
sequential_context_embeddings
=
image_embeddings
.
reshape
(
B
,
self
.
embedding_dim
,
-
1
).
permute
(
0
,
2
,
1
)
B
,
self
.
embedding_dim
,
-
1
).
permute
(
0
,
2
,
1
)
return
(
global_context_embedding
,
return
(
global_context_embedding
,
sequential_context_embeddings
)
sequential_context_embeddings
)
def
_embed_inputs
(
self
,
seqs
,
condition_embedding
=
None
):
def
_embed_inputs
(
self
,
seqs
,
condition_embedding
=
None
):
"""Embeds face sequences and adds within and between face positions.
"""Embeds face sequences and adds within and between face positions.
Args:
Args:
seq: B, seqlen=vlen*3,
seq: B, seqlen=vlen*3,
condition_embedding: B, [c,xs,ys,xe,ye](5), h
condition_embedding: B, [c,xs,ys,xe,ye](5), h
Returns:
Returns:
embeddings: B, seqlen, h
embeddings: B, seqlen, h
"""
"""
B
,
seq_len
=
seqs
.
shape
[:
2
]
B
,
seq_len
=
seqs
.
shape
[:
2
]
# Position embeddings
# Position embeddings
pos_embeddings
=
self
.
pos_embeddings
(
pos_embeddings
=
self
.
pos_embeddings
(
(
torch
.
arange
(
seq_len
,
device
=
seqs
.
device
)
/
self
.
coord_dim
).
floor
().
long
())
# seq_len, h
(
torch
.
arange
(
seq_len
,
device
=
seqs
.
device
)
/
self
.
coord_dim
).
floor
().
long
())
# seq_len, h
# Coord indicators (x, y, z(optional))
# Coord indicators (x, y, z(optional))
coord_embeddings
=
self
.
coord_embed
(
coord_embeddings
=
self
.
coord_embed
(
torch
.
arange
(
seq_len
,
device
=
seqs
.
device
)
%
self
.
coord_dim
)
torch
.
arange
(
seq_len
,
device
=
seqs
.
device
)
%
self
.
coord_dim
)
# Discrete vertex value embeddings
# Discrete vertex value embeddings
vert_embeddings
=
self
.
vertex_embed
(
seqs
)
vert_embeddings
=
self
.
vertex_embed
(
seqs
)
# Aggregate embeddings
# Aggregate embeddings
embeddings
=
vert_embeddings
+
\
embeddings
=
vert_embeddings
+
\
(
coord_embeddings
+
pos_embeddings
)[
None
]
(
coord_embeddings
+
pos_embeddings
)[
None
]
embeddings
=
torch
.
cat
([
condition_embedding
,
embeddings
],
dim
=
1
)
embeddings
=
torch
.
cat
([
condition_embedding
,
embeddings
],
dim
=
1
)
return
embeddings
return
embeddings
def
forward
(
self
,
batch
:
dict
,
**
kwargs
):
def
forward
(
self
,
batch
:
dict
,
**
kwargs
):
"""
"""
Pass batch through face model and get log probabilities.
Pass batch through face model and get log probabilities.
Args:
Args:
batch: Dictionary containing:
batch: Dictionary containing:
'vertices_dequantized': Tensor of shape [batch_size, num_vertices, 3].
'vertices_dequantized': Tensor of shape [batch_size, num_vertices, 3].
'faces': int32 tensor of shape [batch_size, seq_length] with flattened
'faces': int32 tensor of shape [batch_size, seq_length] with flattened
faces.
faces.
'vertices_mask': float32 tensor with shape
'vertices_mask': float32 tensor with shape
[batch_size, num_vertices] that masks padded elements in 'vertices'.
[batch_size, num_vertices] that masks padded elements in 'vertices'.
"""
"""
if
self
.
training
:
if
self
.
training
:
return
self
.
forward_train
(
batch
,
**
kwargs
)
return
self
.
forward_train
(
batch
,
**
kwargs
)
else
:
else
:
return
self
.
inference
(
batch
,
**
kwargs
)
return
self
.
inference
(
batch
,
**
kwargs
)
def
sperate_forward
(
self
,
batch
,
context
,
**
kwargs
):
def
sperate_forward
(
self
,
batch
,
context
,
**
kwargs
):
polyline_length
=
batch
[
'polyline_masks'
].
sum
(
-
1
)
polyline_length
=
batch
[
'polyline_masks'
].
sum
(
-
1
)
c1
,
c2
,
revert_idx
,
size
=
get_chunk_idx
(
polyline_length
)
c1
,
c2
,
revert_idx
,
size
=
get_chunk_idx
(
polyline_length
)
sizes
=
[
size
,
polyline_length
.
max
()]
sizes
=
[
size
,
polyline_length
.
max
()]
polyline_logits
=
[]
polyline_logits
=
[]
for
c_idx
,
size
in
zip
([
c1
,
c2
],
sizes
):
for
c_idx
,
size
in
zip
([
c1
,
c2
],
sizes
):
new_batch
=
assign_batch
(
batch
,
c_idx
,
size
)
new_batch
=
assign_batch
(
batch
,
c_idx
,
size
)
_poly_logits
=
self
.
_forward_train
(
new_batch
,
context
,
**
kwargs
)
_poly_logits
=
self
.
_forward_train
(
new_batch
,
context
,
**
kwargs
)
polyline_logits
.
append
(
_poly_logits
)
polyline_logits
.
append
(
_poly_logits
)
# maybe imporve the speed
# maybe imporve the speed
for
i
,
(
_poly_logits
,
size
)
in
enumerate
(
zip
(
polyline_logits
,
sizes
)):
for
i
,
(
_poly_logits
,
size
)
in
enumerate
(
zip
(
polyline_logits
,
sizes
)):
if
size
<
sizes
[
1
]:
if
size
<
sizes
[
1
]:
_poly_logits
=
F
.
pad
(
_poly_logits
,
(
0
,
0
,
0
,
sizes
[
1
]
-
size
),
"constant"
,
0
)
_poly_logits
=
F
.
pad
(
_poly_logits
,
(
0
,
0
,
0
,
sizes
[
1
]
-
size
),
"constant"
,
0
)
polyline_logits
[
i
]
=
_poly_logits
polyline_logits
[
i
]
=
_poly_logits
polyline_logits
=
torch
.
cat
(
polyline_logits
,
0
)
polyline_logits
=
torch
.
cat
(
polyline_logits
,
0
)
polyline_logits
=
polyline_logits
[
revert_idx
]
polyline_logits
=
polyline_logits
[
revert_idx
]
cat_dist
=
Categorical
(
logits
=
polyline_logits
)
cat_dist
=
Categorical
(
logits
=
polyline_logits
)
return
{
'polylines'
:
cat_dist
}
return
{
'polylines'
:
cat_dist
}
def
forward_train
(
self
,
batch
:
dict
,
context
:
dict
,
**
kwargs
):
def
forward_train
(
self
,
batch
:
dict
,
context
:
dict
,
**
kwargs
):
"""
"""
Returns:
Returns:
pred_dist: Categorical predictive distribution with batch shape
pred_dist: Categorical predictive distribution with batch shape
[batch_size, seq_length].
[batch_size, seq_length].
"""
"""
# we use the gt vertices
# we use the gt vertices
if
False
:
if
False
:
polyline_logits
=
self
.
_forward_train
(
batch
,
context
,
**
kwargs
)
polyline_logits
=
self
.
_forward_train
(
batch
,
context
,
**
kwargs
)
cat_dist
=
Categorical
(
logits
=
polyline_logits
)
cat_dist
=
Categorical
(
logits
=
polyline_logits
)
return
{
'polylines'
:
cat_dist
}
return
{
'polylines'
:
cat_dist
}
else
:
else
:
return
self
.
sperate_forward
(
batch
,
context
,
**
kwargs
)
return
self
.
sperate_forward
(
batch
,
context
,
**
kwargs
)
def
_forward_train
(
self
,
batch
:
dict
,
context
:
dict
,
**
kwargs
):
def
_forward_train
(
self
,
batch
:
dict
,
context
:
dict
,
**
kwargs
):
"""
"""
Returns:
Returns:
pred_dist: Categorical predictive distribution with batch shape
pred_dist: Categorical predictive distribution with batch shape
[batch_size, seq_length].
[batch_size, seq_length].
"""
"""
# we use the gt vertices
# we use the gt vertices
global_context
,
seq_context
=
self
.
_prepare_context
(
global_context
,
seq_context
=
self
.
_prepare_context
(
batch
,
context
)
batch
,
context
)
logits
=
self
.
body
(
logits
=
self
.
body
(
# Last element not used for preds
# Last element not used for preds
batch
[
'polylines'
][:,
:
-
1
],
batch
[
'polylines'
][:,
:
-
1
],
global_context_embedding
=
global_context
,
global_context_embedding
=
global_context
,
sequential_context_embeddings
=
seq_context
,
sequential_context_embeddings
=
seq_context
,
return_logits
=
True
,
return_logits
=
True
,
is_training
=
self
.
training
)
is_training
=
self
.
training
)
return
logits
return
logits
@
force_fp32
(
apply_to
=
(
'global_context_embedding'
,
'sequential_context_embeddings'
,
'cache'
))
@
force_fp32
(
apply_to
=
(
'global_context_embedding'
,
'sequential_context_embeddings'
,
'cache'
))
def
body
(
self
,
def
body
(
self
,
seqs
,
seqs
,
global_context_embedding
=
None
,
global_context_embedding
=
None
,
sequential_context_embeddings
=
None
,
sequential_context_embeddings
=
None
,
temperature
=
1.
,
temperature
=
1.
,
top_k
=
0
,
top_k
=
0
,
top_p
=
1.
,
top_p
=
1.
,
cache
=
None
,
cache
=
None
,
return_logits
=
False
,
return_logits
=
False
,
is_training
=
True
):
is_training
=
True
):
"""
"""
Outputs categorical dist for vertex indices.
Outputs categorical dist for vertex indices.
Body of the face model
Body of the face model
"""
"""
# Embed inputs
# Embed inputs
condition_len
=
global_context_embedding
.
shape
[
1
]
condition_len
=
global_context_embedding
.
shape
[
1
]
decoder_inputs
=
self
.
_embed_inputs
(
decoder_inputs
=
self
.
_embed_inputs
(
seqs
,
global_context_embedding
)
seqs
,
global_context_embedding
)
# Pass through Transformer decoder
# Pass through Transformer decoder
# since our memory efficient decoder only support seq first setting.
# since our memory efficient decoder only support seq first setting.
decoder_inputs
=
decoder_inputs
.
transpose
(
0
,
1
)
decoder_inputs
=
decoder_inputs
.
transpose
(
0
,
1
)
if
sequential_context_embeddings
is
not
None
:
if
sequential_context_embeddings
is
not
None
:
sequential_context_embeddings
=
sequential_context_embeddings
.
transpose
(
sequential_context_embeddings
=
sequential_context_embeddings
.
transpose
(
0
,
1
)
0
,
1
)
causal_msk
=
None
causal_msk
=
None
if
is_training
:
if
is_training
:
causal_msk
=
generate_square_subsequent_mask
(
causal_msk
=
generate_square_subsequent_mask
(
decoder_inputs
.
shape
[
0
],
condition_len
=
condition_len
,
device
=
decoder_inputs
.
device
)
decoder_inputs
.
shape
[
0
],
condition_len
=
condition_len
,
device
=
decoder_inputs
.
device
)
decoder_outputs
,
cache
=
self
.
decoder
(
decoder_outputs
,
cache
=
self
.
decoder
(
tgt
=
decoder_inputs
,
tgt
=
decoder_inputs
,
cache
=
cache
,
cache
=
cache
,
memory
=
sequential_context_embeddings
,
memory
=
sequential_context_embeddings
,
causal_mask
=
causal_msk
,
causal_mask
=
causal_msk
,
)
)
decoder_outputs
=
decoder_outputs
.
transpose
(
0
,
1
)
decoder_outputs
=
decoder_outputs
.
transpose
(
0
,
1
)
# since we only need the predict seq
# since we only need the predict seq
decoder_outputs
=
decoder_outputs
[:,
condition_len
-
1
:]
decoder_outputs
=
decoder_outputs
[:,
condition_len
-
1
:]
# Get logits and optionally process for sampling
# Get logits and optionally process for sampling
logits
=
self
.
_project_to_logits
(
decoder_outputs
)
logits
=
self
.
_project_to_logits
(
decoder_outputs
)
# y mask
# y mask
_vert_mask
=
torch
.
arange
(
logits
.
shape
[
-
1
],
device
=
logits
.
device
)
_vert_mask
=
torch
.
arange
(
logits
.
shape
[
-
1
],
device
=
logits
.
device
)
vertices_mask_y
=
(
_vert_mask
<
self
.
canvas_size
[
1
]
+
1
)
vertices_mask_y
=
(
_vert_mask
<
self
.
canvas_size
[
1
]
+
1
)
vertices_mask_y
[
0
]
=
False
# y position doesn't have stop sign
vertices_mask_y
[
0
]
=
False
# y position doesn't have stop sign
logits
[:,
1
::
self
.
coord_dim
]
=
logits
[:,
1
::
self
.
coord_dim
]
*
\
logits
[:,
1
::
self
.
coord_dim
]
=
logits
[:,
1
::
self
.
coord_dim
]
*
\
vertices_mask_y
-
~
vertices_mask_y
*
1e9
vertices_mask_y
-
~
vertices_mask_y
*
1e9
if
self
.
coord_dim
>
2
:
if
self
.
coord_dim
>
2
:
# z mask
# z mask
_vert_mask
=
torch
.
arange
(
logits
.
shape
[
-
1
],
device
=
logits
.
device
)
_vert_mask
=
torch
.
arange
(
logits
.
shape
[
-
1
],
device
=
logits
.
device
)
vertices_mask_z
=
(
_vert_mask
<
self
.
canvas_size
[
2
]
+
1
)
vertices_mask_z
=
(
_vert_mask
<
self
.
canvas_size
[
2
]
+
1
)
vertices_mask_z
[
0
]
=
False
# y position doesn't have stop sign
vertices_mask_z
[
0
]
=
False
# y position doesn't have stop sign
logits
[:,
2
::
self
.
coord_dim
]
=
logits
[:,
2
::
self
.
coord_dim
]
*
\
logits
[:,
2
::
self
.
coord_dim
]
=
logits
[:,
2
::
self
.
coord_dim
]
*
\
vertices_mask_z
-
~
vertices_mask_z
*
1e9
vertices_mask_z
-
~
vertices_mask_z
*
1e9
logits
=
logits
/
temperature
logits
=
logits
/
temperature
logits
=
top_k_logits
(
logits
,
top_k
)
logits
=
top_k_logits
(
logits
,
top_k
)
logits
=
top_p_logits
(
logits
,
top_p
)
logits
=
top_p_logits
(
logits
,
top_p
)
if
return_logits
:
if
return_logits
:
return
logits
return
logits
cat_dist
=
Categorical
(
logits
=
logits
)
cat_dist
=
Categorical
(
logits
=
logits
)
return
cat_dist
,
cache
return
cat_dist
,
cache
@
force_fp32
(
apply_to
=
(
'pred'
))
@
force_fp32
(
apply_to
=
(
'pred'
))
def
loss
(
self
,
gt
:
dict
,
pred
:
dict
):
def
loss
(
self
,
gt
:
dict
,
pred
:
dict
):
weight
=
gt
[
'polyline_weights'
]
weight
=
gt
[
'polyline_weights'
]
mask
=
gt
[
'polyline_masks'
]
mask
=
gt
[
'polyline_masks'
]
loss
=
-
torch
.
sum
(
loss
=
-
torch
.
sum
(
pred
[
'polylines'
].
log_prob
(
gt
[
'polylines'
])
*
mask
*
weight
)
/
weight
.
sum
()
pred
[
'polylines'
].
log_prob
(
gt
[
'polylines'
])
*
mask
*
weight
)
/
weight
.
sum
()
return
{
'seq'
:
loss
}
return
{
'seq'
:
loss
}
def
inference
(
self
,
def
inference
(
self
,
batch
:
dict
,
batch
:
dict
,
context
:
dict
,
context
:
dict
,
max_sample_length
=
None
,
max_sample_length
=
None
,
temperature
=
1.
,
temperature
=
1.
,
top_k
=
0
,
top_k
=
0
,
top_p
=
1.
,
top_p
=
1.
,
only_return_complete
=
False
,
only_return_complete
=
False
,
gt_condition
=
False
,
gt_condition
=
False
,
**
kwargs
):
**
kwargs
):
"""Sample from face model using caching.
"""Sample from face model using caching.
Args:
Args:
context: Dictionary of context, including 'vertices' and 'vertices_mask'.
context: Dictionary of context, including 'vertices' and 'vertices_mask'.
See _prepare_context for details.
See _prepare_context for details.
max_sample_length: Maximum length of sampled vertex sequences. Sequences
max_sample_length: Maximum length of sampled vertex sequences. Sequences
that do not complete are truncated.
that do not complete are truncated.
temperature: Scalar softmax temperature > 0.
temperature: Scalar softmax temperature > 0.
top_k: Number of tokens to keep for top-k sampling.
top_k: Number of tokens to keep for top-k sampling.
top_p: Proportion of probability mass to keep for top-p sampling.
top_p: Proportion of probability mass to keep for top-p sampling.
only_return_complete: If True, only return completed samples. Otherwise
only_return_complete: If True, only return completed samples. Otherwise
return all samples along with completed indicator.
return all samples along with completed indicator.
Returns:
Returns:
outputs: Output dictionary with fields:
outputs: Output dictionary with fields:
'completed': Boolean tensor of shape [num_samples]. If True then
'completed': Boolean tensor of shape [num_samples]. If True then
corresponding sample completed within max_sample_length.
corresponding sample completed within max_sample_length.
'faces': Tensor of samples with shape [num_samples, num_verts, 3].
'faces': Tensor of samples with shape [num_samples, num_verts, 3].
'valid_polyline_len': Tensor indicating number of vertices for each
'valid_polyline_len': Tensor indicating number of vertices for each
example in padded vertex samples.
example in padded vertex samples.
"""
"""
# prepare the conditional variable
# prepare the conditional variable
global_context
,
seq_context
=
self
.
_prepare_context
(
global_context
,
seq_context
=
self
.
_prepare_context
(
batch
,
context
)
batch
,
context
)
device
=
global_context
.
device
device
=
global_context
.
device
batch_size
=
global_context
.
shape
[
0
]
batch_size
=
global_context
.
shape
[
0
]
# While loop sampling with caching
# While loop sampling with caching
samples
=
torch
.
empty
(
samples
=
torch
.
empty
(
[
batch_size
,
0
],
dtype
=
torch
.
int32
,
device
=
device
)
[
batch_size
,
0
],
dtype
=
torch
.
int32
,
device
=
device
)
max_sample_length
=
max_sample_length
or
self
.
max_seq_length
max_sample_length
=
max_sample_length
or
self
.
max_seq_length
seq_len
=
max_sample_length
*
self
.
coord_dim
+
1
seq_len
=
max_sample_length
*
self
.
coord_dim
+
1
cache
=
None
cache
=
None
decoded_tokens
=
\
decoded_tokens
=
\
torch
.
zeros
((
batch_size
,
seq_len
),
torch
.
zeros
((
batch_size
,
seq_len
),
device
=
device
,
dtype
=
torch
.
long
)
device
=
device
,
dtype
=
torch
.
long
)
remain_idx
=
torch
.
arange
(
batch_size
,
device
=
device
)
remain_idx
=
torch
.
arange
(
batch_size
,
device
=
device
)
for
i
in
range
(
seq_len
):
for
i
in
range
(
seq_len
):
# While-loop body for autoregression calculation.
# While-loop body for autoregression calculation.
pred_dist
,
cache
=
self
.
body
(
pred_dist
,
cache
=
self
.
body
(
samples
,
samples
,
global_context_embedding
=
global_context
,
global_context_embedding
=
global_context
,
sequential_context_embeddings
=
seq_context
,
sequential_context_embeddings
=
seq_context
,
cache
=
cache
,
cache
=
cache
,
temperature
=
temperature
,
temperature
=
temperature
,
top_k
=
top_k
,
top_k
=
top_k
,
top_p
=
top_p
,
top_p
=
top_p
,
is_training
=
False
)
is_training
=
False
)
samples
=
pred_dist
.
sample
()
samples
=
pred_dist
.
sample
()
decoded_tokens
[
remain_idx
,
i
]
=
samples
[:,
-
1
]
decoded_tokens
[
remain_idx
,
i
]
=
samples
[:,
-
1
]
# Stopping conditions for autoregressive calculation.
# Stopping conditions for autoregressive calculation.
if
not
(
decoded_tokens
[:,:
i
+
1
]
!=
0
).
all
(
-
1
).
any
():
if
not
(
decoded_tokens
[:,:
i
+
1
]
!=
0
).
all
(
-
1
).
any
():
break
break
# update state, check the new position is zero.
# update state, check the new position is zero.
valid_idx
=
(
samples
[:,
-
1
]
!=
0
).
nonzero
(
as_tuple
=
True
)[
0
]
valid_idx
=
(
samples
[:,
-
1
]
!=
0
).
nonzero
(
as_tuple
=
True
)[
0
]
remain_idx
=
remain_idx
[
valid_idx
]
remain_idx
=
remain_idx
[
valid_idx
]
cache
=
cache
[:,:,
valid_idx
]
cache
=
cache
[:,:,
valid_idx
]
global_context
=
global_context
[
valid_idx
]
global_context
=
global_context
[
valid_idx
]
seq_context
=
seq_context
[
valid_idx
]
seq_context
=
seq_context
[
valid_idx
]
samples
=
samples
[
valid_idx
]
samples
=
samples
[
valid_idx
]
# decoded_tokens = torch.cat(decoded_tokens,dim=1)
# decoded_tokens = torch.cat(decoded_tokens,dim=1)
decoded_tokens
=
decoded_tokens
[:,:
i
+
1
]
decoded_tokens
=
decoded_tokens
[:,:
i
+
1
]
outputs
=
self
.
post_process
(
decoded_tokens
,
seq_len
,
outputs
=
self
.
post_process
(
decoded_tokens
,
seq_len
,
device
,
only_return_complete
)
device
,
only_return_complete
)
return
outputs
return
outputs
def
post_process
(
self
,
polyline
,
def
post_process
(
self
,
polyline
,
max_seq_len
=
None
,
max_seq_len
=
None
,
device
=
None
,
device
=
None
,
only_return_complete
=
True
):
only_return_complete
=
True
):
'''
'''
format the predictions
format the predictions
find the mask
find the mask
'''
'''
# Record completed samples
# Record completed samples
complete_samples
=
(
polyline
==
0
).
any
(
-
1
)
complete_samples
=
(
polyline
==
0
).
any
(
-
1
)
# Find number of faces
# Find number of faces
sample_seq_length
=
polyline
.
shape
[
-
1
]
sample_seq_length
=
polyline
.
shape
[
-
1
]
_polyline_mask
=
torch
.
arange
(
sample_seq_length
)[
None
].
to
(
device
)
_polyline_mask
=
torch
.
arange
(
sample_seq_length
)[
None
].
to
(
device
)
# Get largest stopping point for incomplete samples.
# Get largest stopping point for incomplete samples.
valid_polyline_len
=
torch
.
full_like
(
polyline
[:,
0
],
sample_seq_length
)
valid_polyline_len
=
torch
.
full_like
(
polyline
[:,
0
],
sample_seq_length
)
zero_inds
=
(
polyline
==
0
).
type
(
torch
.
int32
).
argmax
(
-
1
)
zero_inds
=
(
polyline
==
0
).
type
(
torch
.
int32
).
argmax
(
-
1
)
# Real length
# Real length
valid_polyline_len
[
complete_samples
]
=
zero_inds
[
complete_samples
]
+
1
valid_polyline_len
[
complete_samples
]
=
zero_inds
[
complete_samples
]
+
1
polyline_mask
=
_polyline_mask
<
valid_polyline_len
[:,
None
]
polyline_mask
=
_polyline_mask
<
valid_polyline_len
[:,
None
]
# Mask faces beyond stopping token with zeros
# Mask faces beyond stopping token with zeros
polyline
=
polyline
*
polyline_mask
polyline
=
polyline
*
polyline_mask
# Pad to maximum size with zeros
# Pad to maximum size with zeros
pad_size
=
max_seq_len
-
sample_seq_length
pad_size
=
max_seq_len
-
sample_seq_length
polyline
=
F
.
pad
(
polyline
,
(
0
,
pad_size
))
polyline
=
F
.
pad
(
polyline
,
(
0
,
pad_size
))
# polyline_mask = F.pad(polyline_mask, (0, pad_size))
# polyline_mask = F.pad(polyline_mask, (0, pad_size))
# XXX
# XXX
# if only_return_complete:
# if only_return_complete:
# polyline = polyline[complete_samples]
# polyline = polyline[complete_samples]
# valid_polyline_len = valid_polyline_len[complete_samples]
# valid_polyline_len = valid_polyline_len[complete_samples]
# context = tf.nest.map_structure(
# context = tf.nest.map_structure(
# lambda x: tf.boolean_mask(x, complete_samples), context)
# lambda x: tf.boolean_mask(x, complete_samples), context)
# complete_samples = complete_samples[complete_samples]
# complete_samples = complete_samples[complete_samples]
# outputs
# outputs
outputs
=
{
outputs
=
{
'completed'
:
complete_samples
,
'completed'
:
complete_samples
,
'polylines'
:
polyline
,
'polylines'
:
polyline
,
'polyline_masks'
:
polyline_mask
,
'polyline_masks'
:
polyline_mask
,
}
}
return
outputs
return
outputs
def
find_best_sperate_plan
(
idx
,
array
):
def
find_best_sperate_plan
(
idx
,
array
):
h
=
array
[
-
1
]
-
array
[
idx
]
h
=
array
[
-
1
]
-
array
[
idx
]
w
=
idx
w
=
idx
cost
=
h
*
w
cost
=
h
*
w
return
cost
return
cost
def
get_chunk_idx
(
polyline_length
):
def
get_chunk_idx
(
polyline_length
):
_polyline_length
,
polyline_length_idx
=
torch
.
sort
(
polyline_length
)
_polyline_length
,
polyline_length_idx
=
torch
.
sort
(
polyline_length
)
costs
=
[]
costs
=
[]
for
i
in
range
(
len
(
_polyline_length
)):
for
i
in
range
(
len
(
_polyline_length
)):
cost
=
find_best_sperate_plan
(
i
,
_polyline_length
)
cost
=
find_best_sperate_plan
(
i
,
_polyline_length
)
costs
.
append
(
cost
)
costs
.
append
(
cost
)
seperate_point
=
torch
.
stack
(
costs
).
argmax
()
seperate_point
=
torch
.
stack
(
costs
).
argmax
()
chunk1
=
polyline_length_idx
[:
seperate_point
+
1
]
chunk1
=
polyline_length_idx
[:
seperate_point
+
1
]
chunk2
=
polyline_length_idx
[
seperate_point
+
1
:]
chunk2
=
polyline_length_idx
[
seperate_point
+
1
:]
revert_idx
=
torch
.
argsort
(
polyline_length_idx
)
revert_idx
=
torch
.
argsort
(
polyline_length_idx
)
return
chunk1
,
chunk2
,
revert_idx
,
_polyline_length
[
seperate_point
]
return
chunk1
,
chunk2
,
revert_idx
,
_polyline_length
[
seperate_point
]
def
assign_bev
(
feat
,
idx
):
def
assign_bev
(
feat
,
idx
):
return
feat
[
idx
]
return
feat
[
idx
]
def
assign_batch
(
batch
,
idx
,
size
):
def
assign_batch
(
batch
,
idx
,
size
):
new_batch
=
{}
new_batch
=
{}
for
k
,
v
in
batch
.
items
():
for
k
,
v
in
batch
.
items
():
new_batch
[
k
]
=
v
[
idx
]
new_batch
[
k
]
=
v
[
idx
]
if
new_batch
[
k
].
ndim
>
1
:
if
new_batch
[
k
].
ndim
>
1
:
new_batch
[
k
]
=
new_batch
[
k
][:,:
size
]
new_batch
[
k
]
=
new_batch
[
k
][:,:
size
]
return
new_batch
return
new_batch
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/losses/__init__.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/losses/__init__.py
View file @
f3b13cad
from
.detr_loss
import
LinesLoss
,
MasksLoss
,
LenLoss
from
.detr_loss
import
LinesLoss
,
MasksLoss
,
LenLoss
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/losses/detr_loss.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/losses/detr_loss.py
View file @
f3b13cad
import
torch
import
torch
from
torch
import
nn
as
nn
from
torch
import
nn
as
nn
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
mmdet.models.losses
import
l1_loss
from
mmdet.models.losses
import
l1_loss
from
mmdet.models.losses.utils
import
weighted_loss
from
mmdet.models.losses.utils
import
weighted_loss
import
mmcv
import
mmcv
from
mmdet.models.builder
import
LOSSES
from
mmdet.models.builder
import
LOSSES
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def smooth_l1_loss(pred, target, beta=1.0):
    """Smooth L1 loss.

    Quadratic for |pred - target| < beta, linear beyond it.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.
        beta (float, optional): The threshold in the piecewise function.
            Defaults to 1.0.

    Returns:
        torch.Tensor: Element-wise (unreduced) loss; reduction is applied
        by the ``weighted_loss`` wrapper.
    """
    assert beta > 0
    # Degenerate batch: return a zero that stays connected to the graph.
    if target.numel() == 0:
        return pred.sum() * 0

    assert pred.size() == target.size()
    abs_err = torch.abs(pred - target)
    quadratic = 0.5 * abs_err * abs_err / beta
    linear = abs_err - 0.5 * beta
    return torch.where(abs_err < beta, quadratic, linear)
@LOSSES.register_module()
class LinesLoss(nn.Module):
    """Smooth-L1 regression loss on predicted line coordinates.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum".
        loss_weight (float, optional): Scalar multiplier on the final loss.
        beta (float, optional): Transition point of the smooth-L1
            piecewise function.
    """

    def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5):
        super(LinesLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.beta = beta

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted smooth-L1 loss.

        Args:
            pred (torch.Tensor): The prediction, shape [bs, ...].
            target (torch.Tensor): Learning target, same shape as ``pred``.
            weight (torch.Tensor, optional): Per-element loss weight;
                useful when not all predictions are valid.
            avg_factor (int, optional): Normalizer used when averaging.
            reduction_override (str, optional): Overrides the configured
                reduction for this call.

        Returns:
            torch.Tensor: Reduced loss scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        chosen_reduction = (
            reduction_override if reduction_override else self.reduction)
        raw_loss = smooth_l1_loss(pred,
                                  target,
                                  weight,
                                  reduction=chosen_reduction,
                                  avg_factor=avg_factor,
                                  beta=self.beta)
        return self.loss_weight * raw_loss
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def bce(pred, label, class_weight=None):
    """Element-wise binary cross-entropy with logits.

    Args:
        pred (torch.Tensor): Logits, shape (B, nquery, npts).
        label (torch.Tensor): Binary targets with the same shape as
            ``pred``; cast to float before the BCE call.
        class_weight (torch.Tensor, optional): Passed as ``pos_weight``
            to weight the positive class.

    Returns:
        torch.Tensor: Unreduced per-element loss; reduction is applied by
        the ``weighted_loss`` wrapper.
    """
    # Degenerate batch: return a zero that stays connected to the graph.
    if label.numel() == 0:
        return pred.sum() * 0

    assert pred.size() == label.size()
    return F.binary_cross_entropy_with_logits(pred,
                                              label.float(),
                                              pos_weight=class_weight,
                                              reduction='none')
@LOSSES.register_module()
class MasksLoss(nn.Module):
    """Binary cross-entropy loss on predicted masks.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum".
        loss_weight (float, optional): Scalar multiplier on the final loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MasksLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted BCE loss between ``pred`` and ``target``.

        Args:
            pred (torch.Tensor): Mask logits.
            target (torch.Tensor): Binary mask targets, same shape.
            weight (torch.Tensor, optional): Per-element loss weight.
            avg_factor (int, optional): Normalizer used when averaging.
            reduction_override (str, optional): Overrides the configured
                reduction for this call.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        chosen_reduction = (
            reduction_override if reduction_override else self.reduction)
        raw_loss = bce(pred,
                       target,
                       weight,
                       reduction=chosen_reduction,
                       avg_factor=avg_factor)
        return self.loss_weight * raw_loss
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def ce(pred, label, class_weight=None):
    """Element-wise (unreduced) cross-entropy.

    Args:
        pred (torch.Tensor): Class logits, shape (B*nquery, npts).
        label (torch.Tensor): Class indices, shape (B*nquery,).
        class_weight (torch.Tensor, optional): Per-class weights.

    Returns:
        torch.Tensor: Per-sample loss; reduction is applied by the
        ``weighted_loss`` wrapper.
    """
    # Degenerate batch: return a zero that stays connected to the graph.
    if label.numel() == 0:
        return pred.sum() * 0

    return F.cross_entropy(pred,
                           label,
                           weight=class_weight,
                           reduction='none')
@LOSSES.register_module()
class LenLoss(nn.Module):
    """Cross-entropy loss on predicted polyline lengths.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum".
        loss_weight (float, optional): Scalar multiplier on the final loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(LenLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted cross-entropy loss.

        Args:
            pred (torch.Tensor): Length logits, shape (B*nquery, npts).
            target (torch.Tensor): Length class indices, shape (B*nquery,).
            weight (torch.Tensor, optional): Per-sample loss weight.
            avg_factor (int, optional): Normalizer used when averaging.
            reduction_override (str, optional): Overrides the configured
                reduction for this call.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        chosen_reduction = (
            reduction_override if reduction_override else self.reduction)
        raw_loss = ce(pred,
                      target,
                      weight,
                      reduction=chosen_reduction,
                      avg_factor=avg_factor)
        return self.loss_weight * raw_loss
\ No newline at end of file
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/mapers/__init__.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/mapers/__init__.py
View file @
f3b13cad
from
.vectormapnet
import
VectorMapNet
from
.vectormapnet
import
VectorMapNet
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/mapers/base_mapper.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/mapers/base_mapper.py
View file @
f3b13cad
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmcv.runner
import
auto_fp16
from
mmcv.runner
import
auto_fp16
from
mmcv.utils
import
print_log
from
mmcv.utils
import
print_log
from
mmdet.utils
import
get_root_logger
from
mmdet.utils
import
get_root_logger
from
mmdet3d.models.builder
import
DETECTORS
from
mmdet3d.models.builder
import
DETECTORS
MAPPERS
=
DETECTORS
MAPPERS
=
DETECTORS
class BaseMapper(nn.Module, metaclass=ABCMeta):
    """Base class for mappers.

    Provides the mmdet-style detector interface (properties, train/val
    steps, forward dispatch) that concrete map-construction models
    override. Most hook methods here are placeholders.
    """

    def __init__(self):
        super(BaseMapper, self).__init__()
        # fp16 is disabled by default; mmcv's auto_fp16 machinery reads
        # this flag when mixed-precision is enabled externally.
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the detector has a neck"""
        return hasattr(self, 'neck') and self.neck is not None

    # TODO: these properties need to be carefully handled
    # for both single stage & two stage detectors
    @property
    def with_shared_head(self):
        """bool: whether the detector has a shared head in the RoI Head"""
        return hasattr(self, 'roi_head') and self.roi_head.with_shared_head

    @property
    def with_bbox(self):
        """bool: whether the detector has a bbox head"""
        return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
                or (hasattr(self, 'bbox_head') and self.bbox_head is not None))

    @property
    def with_mask(self):
        """bool: whether the detector has a mask head"""
        return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
                or (hasattr(self, 'mask_head') and self.mask_head is not None))

    #@abstractmethod
    def extract_feat(self, imgs):
        """Extract features from images. Placeholder; subclasses override."""
        pass

    def forward_train(self, *args, **kwargs):
        # Placeholder; concrete mappers implement the training forward pass.
        pass

    #@abstractmethod
    def simple_test(self, img, img_metas, **kwargs):
        # Placeholder; concrete mappers implement single-pass inference.
        pass

    #@abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation. Placeholder."""
        pass

    def init_weights(self, pretrained=None):
        """Initialize the weights in detector.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        # Only logs the checkpoint path; actual loading is handled elsewhere.
        if pretrained is not None:
            logger = get_root_logger()
            print_log(f'load model from: {pretrained}', logger=logger)

    def forward_test(self, *args, **kwargs):
        """Test-time forward dispatch.

        NOTE(review): the ``if True`` makes the ``aug_test`` branch
        unreachable, both calls pass no arguments, and nothing is
        returned — this reads as placeholder code that subclasses
        (e.g. VectorMapNet) are expected to override; confirm before
        relying on the base implementation.
        """
        if True:
            self.simple_test()
        else:
            self.aug_test()

    # @auto_fp16(apply_to=('img', ))
    def forward(self, *args, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(*args, **kwargs)
        else:
            # 'rescale' is injected by the mmdet test pipeline but unused
            # here, so it is dropped before dispatching.
            kwargs.pop('rescale')
            return self.forward_test(*args, **kwargs)

    def train_step(self, data_dict, optimizer):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data_dict (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
                ``num_samples``.

                - ``loss`` is a tensor for back propagation, which can be a \
                weighted sum of multiple losses.
                - ``log_vars`` contains all the variables to be sent to the
                logger.
                - ``num_samples`` indicates the batch size (when the model is \
                DDP, it means the batch size on each GPU), which is used for \
                averaging the logs.
        """
        loss, log_vars, num_samples = self(**data_dict)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=num_samples)

        return outputs

    def val_step(self, data, optimizer):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but used
        during val epochs. Note that the evaluation after training epochs is
        not implemented with this method, but an evaluation hook.
        """
        loss, log_vars, num_samples = self(**data)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=num_samples)

        return outputs

    def show_result(self,
                    **kwargs):
        # Placeholder visualization hook: always returns None.
        img = None
        return img
\ No newline at end of file
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/mapers/vectormapnet.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/mapers/vectormapnet.py
View file @
f3b13cad
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn.utils.rnn
import
pad_sequence
from
torch.nn.utils.rnn
import
pad_sequence
from
torchvision.models.resnet
import
resnet18
,
resnet50
from
torchvision.models.resnet
import
resnet18
,
resnet50
from
mmdet3d.models.builder
import
(
build_backbone
,
build_head
,
from
mmdet3d.models.builder
import
(
build_backbone
,
build_head
,
build_neck
)
build_neck
)
from
.base_mapper
import
BaseMapper
,
MAPPERS
from
.base_mapper
import
BaseMapper
,
MAPPERS
@MAPPERS.register_module()
class VectorMapNet(BaseMapper):
    """VectorMapNet mapper: BEV backbone -> neck -> detection/generation head.

    Builds the BEV backbone and head from configs, and wires either a
    config-driven multiscale neck or a default resnet18-based neck.
    """

    def __init__(self,
                 backbone_cfg=dict(),
                 head_cfg=dict(
                     vert_net_cfg=dict(),
                     face_net_cfg=dict(),
                 ),
                 neck_input_channels=128,
                 neck_cfg=None,
                 with_auxiliary_head=False,
                 only_det=False,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 model_name=None, **kwargs):
        super(VectorMapNet, self).__init__()

        #Attribute
        self.model_name = model_name
        # Cache of the last non-empty batch, used to survive batches with
        # no valid vectors (see forward_train).
        self.last_epoch = None
        self.only_det = only_det

        self.backbone = build_backbone(backbone_cfg)

        if neck_cfg is not None:
            # Config-driven neck: a backbone whose stem conv is replaced to
            # accept `neck_input_channels`, followed by a projection neck.
            self.neck_neck = build_backbone(neck_cfg.backbone)
            self.neck_neck.conv1 = nn.Conv2d(
                neck_input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.neck_project = build_neck(neck_cfg.neck)
            self.neck = self.multiscale_neck
        else:
            # Default neck: resnet18 stem + layer1 with a 1x1 projection
            # to 128 channels. Only `trunk.layer1` is kept from resnet18.
            trunk = resnet18(pretrained=False, zero_init_residual=True)
            self.neck = nn.Sequential(
                nn.Conv2d(neck_input_channels, 64, kernel_size=(7, 7), stride=(
                    2, 2), padding=(3, 3), bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                             dilation=1, ceil_mode=False),
                trunk.layer1,
                nn.Conv2d(64, 128, kernel_size=1, bias=False),
            )

        # BEV
        # Mirror the backbone's BEV grid size when it exposes one.
        if hasattr(self.backbone, 'bev_w'):
            self.bev_w = self.backbone.bev_w
            self.bev_h = self.backbone.bev_h

        self.head = build_head(head_cfg)

    def multiscale_neck(self, bev_embedding):
        """Run the config-driven neck: backbone then projection."""
        multi_feat = self.neck_neck(bev_embedding)
        multi_feat = self.neck_project(multi_feat)

        return multi_feat

    def forward_train(self, img, polys, points=None, img_metas=None, **kwargs):
        '''
        Args:
            img: torch.Tensor of shape [B, N, 3, H, W]
                N: number of cams
            vectors: list[list[Tuple(lines, length, label)]]
                - lines: np.array of shape [num_points, 2].
                - length: int
                - label: int
                len(vectors) = batch_size
                len(vectors[_b]) = num of lines in sample _b
            img_metas:
                img_metas['lidar2img']: [B, N, 4, 4]
        Out:
            loss, log_vars, num_sample
        '''

        # prepare labels and images
        batch, img, img_metas, valid_idx, points = self.batch_data(
            polys, img, img_metas, img.device, points)

        # corner cases use hard code to prevent code fail
        # If every sample in this batch had no vectors, fall back to the
        # cached last non-empty batch so the training step still runs.
        if self.last_epoch is None:
            self.last_epoch = [batch, img, img_metas, valid_idx, points]

        if len(valid_idx) == 0:
            batch, img, img_metas, valid_idx, points = self.last_epoch
        else:
            # NOTE(review): `del` immediately followed by re-assignment is
            # redundant but harmless; kept as-is.
            del self.last_epoch
            self.last_epoch = [batch, img, img_metas, valid_idx, points]

        # Backbone
        _bev_feats = self.backbone(img, img_metas=img_metas, points=points)
        # One (H, W) spatial-shape entry per sample in the batch.
        img_shape = \
            [_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]

        # Neck
        bev_feats = self.neck(_bev_feats)

        preds_dict, losses_dict = \
            self.head(batch,
                      context={
                          'bev_embeddings': bev_feats,
                          'batch_input_shape': _bev_feats.shape[2:],
                          'img_shape': img_shape,
                          'raw_bev_embeddings': _bev_feats},
                      only_det=self.only_det)

        # format outputs
        # Total loss is the plain sum of all head losses.
        loss = 0
        for name, var in losses_dict.items():
            loss = loss + var

        # update the log
        log_vars = {k: v.item() for k, v in losses_dict.items()}
        log_vars.update({'total': loss.item()})

        num_sample = img.size(0)

        return loss, log_vars, num_sample

    @torch.no_grad()
    def forward_test(self, img, polys=None, points=None, img_metas=None, **kwargs):
        '''
        inference pipeline
        '''

        # prepare labels and images
        token = []
        for img_meta in img_metas:
            token.append(img_meta['token'])

        _bev_feats = self.backbone(img, img_metas, points=points)
        img_shape = [_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]

        # Neck
        bev_feats = self.neck(_bev_feats)

        context = {'bev_embeddings': bev_feats,
                   'batch_input_shape': _bev_feats.shape[2:],
                   'img_shape': img_shape,  # XXX
                   'raw_bev_embeddings': _bev_feats}

        preds_dict = self.head(batch={},
                               context=context,
                               condition_on_det=True,
                               gt_condition=False,
                               only_det=self.only_det)

        # Hard Code
        # Head may return None (no detections); signal it to the caller.
        if preds_dict is None:
            return [None]

        results_list = self.head.post_process(preds_dict, token, only_det=self.only_det)

        return results_list

    def batch_data(self, polys, imgs, img_metas, device, points=None):
        """Filter out samples with no vectors and build det/gen targets.

        Returns (batch, imgs, img_metas, valid_idx, points); all entries
        are None (except valid_idx) when no sample has vectors.
        """
        # filter none vector's case
        valid_idx = [i for i in range(len(polys)) if len(polys[i])]
        imgs = imgs[valid_idx]
        img_metas = [img_metas[i] for i in valid_idx]
        polys = [polys[i] for i in valid_idx]
        if points is not None:
            points = [points[i] for i in valid_idx]
            points = self.batch_points(points)

        if len(valid_idx) == 0:
            return None, None, None, valid_idx, None

        batch = {}
        batch['det'] = format_det(polys, device)
        batch['gen'] = format_gen(polys, device)

        return batch, imgs, img_metas, valid_idx, points

    def batch_points(self, points):
        """Pad a list of variable-length point tensors into one batch.

        Returns (padded_points, mask) where mask marks the valid
        (non-padding) positions per sample.
        """
        pad_points = pad_sequence(points, batch_first=True)
        points_mask = torch.zeros_like(pad_points[:,:,0]).bool()
        for i in range(len(points)):
            valid_num = points[i].shape[0]
            points_mask[i][:valid_num] = True

        return (pad_points, points_mask)
def format_det(polys, device):
    """Collect per-sample detection targets onto ``device``.

    Args:
        polys (list[dict]): One dict per sample; each must contain numpy
            arrays under 'det_label' (class labels) and 'keypoint'
            (keypoint coordinates).
        device: Target device for the resulting tensors.

    Returns:
        dict: Keys 'class_label' and 'bbox' hold per-sample tensor lists.
            'batch_idx' is kept (always empty) to preserve the original
            output schema — NOTE(review): it is never populated here;
            confirm downstream consumers expect that.
    """
    class_labels = []
    bboxes = []
    for poly in polys:
        class_labels.append(torch.from_numpy(poly['det_label']).to(device))
        bboxes.append(torch.from_numpy(poly['keypoint']).to(device))
    return {
        'class_label': class_labels,
        'batch_idx': [],
        'bbox': bboxes,
    }
def format_gen(polys, device):
    """Flatten per-sample polyline-generation targets into one batch.

    Fix: the original initialized ``line_cls = []`` twice — the first
    assignment was dead code, immediately clobbered by the tuple
    assignment below; it has been removed.

    Args:
        polys (list[dict]): One dict per sample. Values are numpy arrays
            or lists of numpy arrays; they are converted in place to
            tensors on ``device``. Required keys: 'gen_label',
            'qkeypoint', 'polylines', 'polyline_masks',
            'polyline_weights'.
        device: Target device for all resulting tensors.

    Returns:
        dict: Flattened batch with
            - 'lines_bs_idx' (long): sample index of each line,
            - 'lines_cls' (long): class of each line,
            - 'bbox_flat': stacked conditioning keypoints,
            - 'polylines' / 'polyline_masks' / 'polyline_weights':
              zero-padded to the longest polyline (batch_first).
    """
    polylines, polyline_masks, polyline_weights = [], [], []
    bbox, line_cls, line_bs_idx = [], [], []
    for batch_idx, poly in enumerate(polys):
        # Convert every entry to device tensors, in place.
        for k in poly.keys():
            if isinstance(poly[k], np.ndarray):
                poly[k] = torch.from_numpy(poly[k]).to(device)
            else:
                poly[k] = [torch.from_numpy(v).to(device) for v in poly[k]]

        line_cls += poly['gen_label']
        # Remember which sample each line came from.
        line_bs_idx += [batch_idx] * len(poly['gen_label'])

        # condition
        bbox += poly['qkeypoint']

        # out
        polylines += poly['polylines']
        polyline_masks += poly['polyline_masks']
        polyline_weights += poly['polyline_weights']

    batch = {}
    batch['lines_bs_idx'] = torch.tensor(
        line_bs_idx, dtype=torch.long, device=device)
    batch['lines_cls'] = torch.tensor(
        line_cls, dtype=torch.long, device=device)
    batch['bbox_flat'] = torch.stack(bbox, 0)

    # padding to the longest polyline in the batch
    batch['polylines'] = pad_sequence(polylines, batch_first=True)
    batch['polyline_masks'] = pad_sequence(polyline_masks, batch_first=True)
    batch['polyline_weights'] = pad_sequence(polyline_weights, batch_first=True)

    return batch
\ No newline at end of file
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/transformer_utils/__init__.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/__init__.py
View file @
f3b13cad
from
.deformable_transformer
import
DeformableDetrTransformer_
,
DeformableDetrTransformerDecoder_
from
.deformable_transformer
import
DeformableDetrTransformer_
,
DeformableDetrTransformerDecoder_
from
.base_transformer
import
PlaceHolderEncoder
from
.base_transformer
import
PlaceHolderEncoder
\ No newline at end of file
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/transformer_utils/base_transformer.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/base_transformer.py
View file @
f3b13cad
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
mmcv.cnn
import
xavier_init
,
constant_init
from
mmcv.cnn
import
xavier_init
,
constant_init
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
TRANSFORMER_LAYER_SEQUENCE
)
TRANSFORMER_LAYER_SEQUENCE
)
from
mmcv.cnn.bricks.transformer
import
(
MultiScaleDeformableAttention
,
from
mmcv.cnn.bricks.transformer
import
(
MultiScaleDeformableAttention
,
TransformerLayerSequence
,
TransformerLayerSequence
,
build_transformer_layer_sequence
)
build_transformer_layer_sequence
)
from
mmcv.runner.base_module
import
BaseModule
from
mmcv.runner.base_module
import
BaseModule
from
mmdet.models.utils.builder
import
TRANSFORMER
from
mmdet.models.utils.builder
import
TRANSFORMER
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class PlaceHolderEncoder(nn.Module):
    """Identity stand-in for a transformer encoder.

    Registered in the layer-sequence registry so configs may fill the
    "encoder" slot without performing any encoding: ``forward`` hands the
    query back untouched.
    """

    def __init__(self, *args, embed_dims=None, **kwargs):
        # Any extra positional/keyword config options are accepted and
        # silently ignored so arbitrary encoder configs can be swapped in.
        super().__init__()
        # Kept so callers that read ``encoder.embed_dims`` still work.
        self.embed_dims = embed_dims

    def forward(self, *args, query=None, **kwargs):
        """Return ``query`` unchanged (no-op encoder)."""
        return query
\ No newline at end of file
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/transformer_utils/deformable_transformer.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/deformable_transformer.py
View file @
f3b13cad
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
math
import
math
import
warnings
import
warnings
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmcv.cnn
import
build_activation_layer
,
build_norm_layer
,
xavier_init
from
mmcv.cnn
import
build_activation_layer
,
build_norm_layer
,
xavier_init
from
mmcv.cnn.bricks.registry
import
(
TRANSFORMER_LAYER
,
from
mmcv.cnn.bricks.registry
import
(
TRANSFORMER_LAYER
,
TRANSFORMER_LAYER_SEQUENCE
)
TRANSFORMER_LAYER_SEQUENCE
)
from
mmcv.cnn.bricks.transformer
import
(
BaseTransformerLayer
,
from
mmcv.cnn.bricks.transformer
import
(
BaseTransformerLayer
,
TransformerLayerSequence
,
TransformerLayerSequence
,
build_transformer_layer_sequence
)
build_transformer_layer_sequence
)
from
mmcv.runner.base_module
import
BaseModule
from
mmcv.runner.base_module
import
BaseModule
from
torch.nn.init
import
normal_
from
torch.nn.init
import
normal_
from
mmdet.models.utils.builder
import
TRANSFORMER
from
mmdet.models.utils.builder
import
TRANSFORMER
from
mmdet.models.utils.transformer
import
Transformer
from
mmdet.models.utils.transformer
import
Transformer
try
:
try
:
from
mmcv.ops.multi_scale_deform_attn
import
MultiScaleDeformableAttention
from
mmcv.ops.multi_scale_deform_attn
import
MultiScaleDeformableAttention
except
ImportError
:
except
ImportError
:
warnings
.
warn
(
warnings
.
warn
(
'`MultiScaleDeformableAttention` in MMCV has been moved to '
'`MultiScaleDeformableAttention` in MMCV has been moved to '
'`mmcv.ops.multi_scale_deform_attn`, please update your MMCV'
)
'`mmcv.ops.multi_scale_deform_attn`, please update your MMCV'
)
from
mmcv.cnn.bricks.transformer
import
MultiScaleDeformableAttention
from
mmcv.cnn.bricks.transformer
import
MultiScaleDeformableAttention
from
.fp16_dattn
import
MultiScaleDeformableAttentionFp16
from
.fp16_dattn
import
MultiScaleDeformableAttentionFp16
def
inverse_sigmoid
(
x
,
eps
=
1e-5
):
def
inverse_sigmoid
(
x
,
eps
=
1e-5
):
"""Inverse function of sigmoid.
"""Inverse function of sigmoid.
Args:
Args:
x (Tensor): The tensor to do the
x (Tensor): The tensor to do the
inverse.
inverse.
eps (float): EPS avoid numerical
eps (float): EPS avoid numerical
overflow. Defaults 1e-5.
overflow. Defaults 1e-5.
Returns:
Returns:
Tensor: The x has passed the inverse
Tensor: The x has passed the inverse
function of sigmoid, has same
function of sigmoid, has same
shape with input.
shape with input.
"""
"""
x
=
x
.
clamp
(
min
=
0
,
max
=
1
)
x
=
x
.
clamp
(
min
=
0
,
max
=
1
)
x1
=
x
.
clamp
(
min
=
eps
)
x1
=
x
.
clamp
(
min
=
eps
)
x2
=
(
1
-
x
).
clamp
(
min
=
eps
)
x2
=
(
1
-
x
).
clamp
(
min
=
eps
)
return
torch
.
log
(
x1
/
x2
)
return
torch
.
log
(
x1
/
x2
)
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DeformableDetrTransformerDecoder_(TransformerLayerSequence):
    """Deformable-DETR style decoder with iterative reference-point refinement.

    Args:
        return_intermediate (bool): Whether to return the outputs (and
            reference points) of every decoder layer instead of only the
            last one.
        coord_dim (int): Number of leading channels of ``reference_points``
            that are overwritten with the refined points after each layer.
        kp_coord_dim (int): Number of leading channels of
            ``reference_points`` fed to the attention layers as sampling
            locations (e.g. 2 for (x, y)).
    """

    def __init__(self, *args,
                 return_intermediate=False,
                 coord_dim=2,
                 kp_coord_dim=2,
                 **kwargs):
        super(DeformableDetrTransformerDecoder_, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.coord_dim = coord_dim
        self.kp_coord_dim = kp_coord_dim

    def forward(self,
                query,
                *args,
                reference_points=None,
                valid_ratios=None,
                reg_branches=None,
                **kwargs):
        """Run the decoder layers, refining reference points layer by layer.

        Args:
            query (Tensor): Input query, shape (num_query, bs, embed_dims).
            reference_points (Tensor): Reference points for deformable
                sampling; mutated IN PLACE (with detached values) by the
                refinement below, so the caller's tensor is updated too.
            valid_ratios (Tensor): Ratios of valid (un-padded) area per
                feature level, shape (bs, num_levels, 2).
            reg_branches (nn.ModuleList | None): Per-layer regression heads;
                only passed when box refinement is enabled, otherwise None.

        Returns:
            tuple: ``(output, reference_points)`` for the last layer, or the
            stacked per-layer versions of both when ``return_intermediate``
            is True.
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            # Scale the (normalized) reference points by the per-level valid
            # ratios so sampling locations stay inside the un-padded region.
            reference_points_input = \
                reference_points[:, :, None, :self.kp_coord_dim] * \
                valid_ratios[:, None, :, :self.kp_coord_dim]
            # if reference_points.shape[-1] == 3 and self.kp_coord_dim==2:
            output = layer(
                output,
                *args,
                reference_points=reference_points_input[..., :self.kp_coord_dim],
                **kwargs)
            # Layer output is (num_query, bs, dim); reg branches expect
            # batch-first, so permute, refine, then permute back below.
            output = output.permute(1, 0, 2)

            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                # NOTE(review): ``new_reference_points`` ALIASES ``tmp`` —
                # the slice assignment below mutates ``tmp`` in place as well.
                new_reference_points = tmp
                # Refine in logit space: predicted offset + logit of current
                # reference points. NOTE(review): ``inverse_sigmoid`` is
                # applied to the FULL last dim of ``reference_points``; this
                # assumes its last dim matches ``kp_coord_dim`` here —
                # confirm against callers.
                new_reference_points[..., :self.kp_coord_dim] = tmp[
                    ..., :self.kp_coord_dim] + inverse_sigmoid(reference_points)
                new_reference_points = new_reference_points.sigmoid()
                if reference_points.shape[-1] == 3 and self.kp_coord_dim == 2:
                    # Third channel (presumably a height/z coordinate) is
                    # predicted directly in sigmoid space rather than
                    # refined — TODO confirm intended semantics.
                    reference_points[..., -1] = tmp[..., -1].sigmoid().detach()
                # In-place, detached update: refinement does not backprop
                # through the reference points across layers.
                reference_points[..., :self.coord_dim] = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points
@TRANSFORMER.register_module()
class DeformableDetrTransformer_(Transformer):
    """Deformable-DETR transformer with the encoder pass bypassed.

    Unlike the stock mmdet implementation, the encoder invocation in
    ``forward`` is commented out and the flattened backbone features are
    handed to the decoder directly as ``memory``.

    Args:
        as_two_stage (bool): Generate queries from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 1.
        two_stage_num_proposals (int): Number of proposals when
            ``as_two_stage`` is True. Default: 300.
        coord_dim (int): Dimensionality of the predicted initial reference
            points (output size of ``reference_points_embed``). Default: 2.
    """

    def __init__(self,
                 as_two_stage=False,
                 num_feature_levels=1,
                 two_stage_num_proposals=300,
                 coord_dim=2,
                 **kwargs):
        super(DeformableDetrTransformer_, self).__init__(**kwargs)
        self.as_two_stage = as_two_stage
        self.num_feature_levels = num_feature_levels
        self.two_stage_num_proposals = two_stage_num_proposals
        # ``self.encoder`` is constructed by the base ``Transformer.__init__``.
        self.embed_dims = self.encoder.embed_dims
        self.coord_dim = coord_dim
        self.init_layers()

    def init_layers(self):
        """Initialize layers of the DeformableDetrTransformer."""
        # Per-level embedding added to the positional encodings.
        # ``torch.Tensor`` leaves it UNinitialized here; it is filled with
        # normal noise in ``init_weights``.
        self.level_embeds = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        if self.as_two_stage:
            self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
            self.enc_output_norm = nn.LayerNorm(self.embed_dims)
            self.pos_trans = nn.Linear(self.embed_dims * 2,
                                       self.embed_dims * 2)
            self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
        else:
            # Predicts the initial (sigmoid-space) reference points from the
            # positional half of the query embedding.
            self.reference_points_embed = nn.Linear(self.embed_dims, self.coord_dim)

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # Deformable attention modules have their own init (sampling-offset
        # biases etc.), so re-run it after the blanket xavier above.
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
            elif isinstance(m, MultiScaleDeformableAttentionFp16):
                m.init_weights()
        if not self.as_two_stage:
            xavier_init(self.reference_points_embed, distribution='uniform', bias=0.)
        normal_(self.level_embeds)

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Get the reference points used in the (bypassed) encoder.

        NOTE(review): the only call site in this class's ``forward`` is
        commented out, so this helper appears unused here.

        Args:
            spatial_shapes (Tensor): Shapes of all feature maps,
                shape (num_level, 2).
            valid_ratios (Tensor): Ratios of valid points on each feature
                map, shape (bs, num_levels, 2).
            device (obj:`device`): Device the reference points should be on.

        Returns:
            Tensor: Reference points, shape (bs, num_keys, num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            # TODO check this 0.5
            # Pixel-center grid, normalized by the valid (un-padded) extent.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def get_valid_ratio(self, mask):
        """Get the valid (un-padded) width/height ratios of a feature map.

        Args:
            mask (Tensor): Padding mask, shape (bs, H, W); True marks
                padded positions (hence the ``~mask`` below).

        Returns:
            Tensor: Per-sample (ratio_w, ratio_h), shape (bs, 2).
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def get_proposal_pos_embed(self,
                               proposals,
                               num_pos_feats=128,
                               temperature=10000):
        """Get the sinusoidal position embedding of proposals."""
        scale = 2 * math.pi
        dim_t = torch.arange(
            num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, 64, 2
        # Interleave sin/cos over the frequency dim, then flatten.
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
                          dim=4).flatten(2)
        return pos

    def forward(self,
                mlvl_feats,
                mlvl_masks,
                query_embed,
                mlvl_pos_embeds,
                reg_branches=None,
                cls_branches=None,
                **kwargs):
        """Flatten multi-level features and run the decoder.

        Args:
            mlvl_feats (list[Tensor]): Feature maps from different levels,
                each of shape [bs, embed_dims, h, w].
            mlvl_masks (list[Tensor]): Key padding masks per level, each of
                shape [bs, h, w].
            query_embed (Tensor): Query embedding of shape
                [num_query, 2 * embed_dims]; split below into
                (query_pos, query) along the last dim.
            mlvl_pos_embeds (list[Tensor]): Positional encodings per level,
                each of shape [bs, embed_dims, h, w].
            reg_branches (obj:`nn.ModuleList`, optional): Regression heads
                passed to the decoder for iterative refinement; only given
                when `with_box_refine` is True.
            cls_branches (obj:`nn.ModuleList`, optional): Accepted for
                interface compatibility; not used in this implementation.

        Returns:
            tuple[Tensor]: ``(inter_states, init_reference_out,
            inter_references_out)`` — decoder outputs (stacked per layer
            when the decoder returns intermediates), the initial reference
            points, and the per-layer refined reference points.
        """
        assert self.as_two_stage or query_embed is not None

        feat_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (feat, mask, pos_embed) in enumerate(
                zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            # (bs, c, h, w) -> (bs, h*w, c)
            feat = feat.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            # Tag each level's positional encoding with a learned level embed.
            lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            feat_flatten.append(feat)
            mask_flatten.append(mask)
        feat_flatten = torch.cat(feat_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=feat_flatten.device)
        # Offsets of each level inside the flattened key sequence.
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack(
            [self.get_valid_ratio(m) for m in mlvl_masks], 1)

        # Encoder is intentionally bypassed; features go straight to the
        # decoder as memory.
        # reference_points = \
        #     self.get_reference_points(spatial_shapes,
        #                               valid_ratios,
        #                               device=feat.device)
        feat_flatten = feat_flatten.permute(1, 0, 2)  # (H*W, bs, embed_dims)
        # lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
        #     1, 0, 2)  # (H*W, bs, embed_dims)
        # memory = self.encoder(
        #     query=feat_flatten,
        #     key=None,
        #     value=None,
        #     query_pos=lvl_pos_embed_flatten,
        #     query_key_padding_mask=mask_flatten,
        #     spatial_shapes=spatial_shapes,
        #     reference_points=reference_points,
        #     level_start_index=level_start_index,
        #     valid_ratios=valid_ratios,
        #     **kwargs)
        # NOTE(review): this permute undoes the one just above, so memory is
        # back to (bs, H*W, c) here.
        memory = feat_flatten.permute(1, 0, 2)
        bs, _, c = memory.shape
        # First half of the embedding is the positional part, second half
        # the content query.
        query_pos, query = torch.split(query_embed, c, dim=-1)
        reference_points = self.reference_points_embed(query_pos).sigmoid()
        init_reference_out = reference_points

        # decoder
        # Decoder expects sequence-first tensors: (num_query, bs, c).
        query = query.permute(1, 0, 2)
        memory = memory.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=memory,
            query_pos=query_pos,
            key_padding_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=reg_branches,
            **kwargs)

        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out
\ No newline at end of file
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/src/models/transformer_utils/fp16_dattn.py
→
autonomous_driving/Online-HD-Map-Construction/src/models/transformer_utils/fp16_dattn.py
View file @
f3b13cad
from
turtle
import
forward
from
turtle
import
forward
import
warnings
import
warnings
try
:
try
:
from
mmcv.ops.multi_scale_deform_attn
import
MultiScaleDeformableAttention
from
mmcv.ops.multi_scale_deform_attn
import
MultiScaleDeformableAttention
except
ImportError
:
except
ImportError
:
warnings
.
warn
(
warnings
.
warn
(
'`MultiScaleDeformableAttention` in MMCV has been moved to '
'`MultiScaleDeformableAttention` in MMCV has been moved to '
'`mmcv.ops.multi_scale_deform_attn`, please update your MMCV'
)
'`mmcv.ops.multi_scale_deform_attn`, please update your MMCV'
)
from
mmcv.cnn.bricks.transformer
import
MultiScaleDeformableAttention
from
mmcv.cnn.bricks.transformer
import
MultiScaleDeformableAttention
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmcv.cnn.bricks.registry
import
ATTENTION
from
mmcv.cnn.bricks.registry
import
ATTENTION
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
,
Sequential
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
,
Sequential
from
mmcv.cnn.bricks.transformer
import
build_attention
from
mmcv.cnn.bricks.transformer
import
build_attention
import
math
import
math
import
warnings
import
warnings
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.autograd.function
import
Function
,
once_differentiable
from
torch.autograd.function
import
Function
,
once_differentiable
from
mmcv
import
deprecated_api_warning
from
mmcv
import
deprecated_api_warning
from
mmcv.cnn
import
constant_init
,
xavier_init
from
mmcv.cnn
import
constant_init
,
xavier_init
from
mmcv.cnn.bricks.registry
import
ATTENTION
from
mmcv.cnn.bricks.registry
import
ATTENTION
from
mmcv.runner
import
BaseModule
from
mmcv.runner
import
BaseModule
from
mmcv.utils
import
ext_loader
from
mmcv.utils
import
ext_loader
ext_module
=
ext_loader
.
load_ext
(
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
@ATTENTION.register_module()
class MultiScaleDeformableAttentionFp16(BaseModule):
    """fp32-forcing wrapper around a deformable attention module.

    Builds the wrapped attention from ``attn_cfg`` and casts the tensor
    arguments of ``forward`` to fp32 via ``force_fp32`` before delegating,
    so deformable attention runs in full precision under mixed-precision
    training.
    """

    def __init__(self, attn_cfg=None, init_cfg=None, **kwargs):
        super(MultiScaleDeformableAttentionFp16, self).__init__(init_cfg)
        # Build and immediately initialize the wrapped attention module.
        self.deformable_attention = build_attention(attn_cfg)
        self.deformable_attention.init_weights()
        # Marker consumed by mmcv's fp16 utilities.
        self.fp16_enabled = False

    @force_fp32(
        apply_to=('query', 'key', 'value', 'query_pos', 'reference_points',
                  'identity'))
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """Delegate to the wrapped attention with fp32-cast inputs."""
        attn_kwargs = dict(
            key=key,
            value=value,
            identity=identity,
            query_pos=query_pos,
            key_padding_mask=key_padding_mask,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
        )
        attn_kwargs.update(kwargs)
        return self.deformable_attention(query, **attn_kwargs)
class
MultiScaleDeformableAttnFunctionFp32
(
Function
):
class
MultiScaleDeformableAttnFunctionFp32
(
Function
):
@
staticmethod
@
staticmethod
@
custom_fwd
(
cast_inputs
=
torch
.
float32
)
@
custom_fwd
(
cast_inputs
=
torch
.
float32
)
def
forward
(
ctx
,
value
,
value_spatial_shapes
,
value_level_start_index
,
def
forward
(
ctx
,
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
im2col_step
):
sampling_locations
,
attention_weights
,
im2col_step
):
"""GPU version of multi-scale deformable attention.
"""GPU version of multi-scale deformable attention.
Args:
Args:
value (Tensor): The value has shape
value (Tensor): The value has shape
(bs, num_keys, mum_heads, embed_dims//num_heads)
(bs, num_keys, mum_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
sampling_locations (Tensor): The location of sampling points,
has shape
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
attention_weights (Tensor): The weight of sampling points used
when calculate the attention, has shape
when calculate the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
(bs ,num_queries, num_heads, num_levels, num_points),
im2col_step (Tensor): The step used in image to column.
im2col_step (Tensor): The step used in image to column.
Returns:
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
Tensor: has shape (bs, num_queries, embed_dims)
"""
"""
ctx
.
im2col_step
=
im2col_step
ctx
.
im2col_step
=
im2col_step
output
=
ext_module
.
ms_deform_attn_forward
(
output
=
ext_module
.
ms_deform_attn_forward
(
value
,
value
,
value_spatial_shapes
,
value_spatial_shapes
,
value_level_start_index
,
value_level_start_index
,
sampling_locations
,
sampling_locations
,
attention_weights
,
attention_weights
,
im2col_step
=
ctx
.
im2col_step
)
im2col_step
=
ctx
.
im2col_step
)
ctx
.
save_for_backward
(
value
,
value_spatial_shapes
,
ctx
.
save_for_backward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
value_level_start_index
,
sampling_locations
,
attention_weights
)
attention_weights
)
return
output
return
output
@
staticmethod
@
staticmethod
@
once_differentiable
@
once_differentiable
@
custom_bwd
@
custom_bwd
def
backward
(
ctx
,
grad_output
):
def
backward
(
ctx
,
grad_output
):
"""GPU version of backward function.
"""GPU version of backward function.
Args:
Args:
grad_output (Tensor): Gradient
grad_output (Tensor): Gradient
of output tensor of forward.
of output tensor of forward.
Returns:
Returns:
Tuple[Tensor]: Gradient
Tuple[Tensor]: Gradient
of input tensors in forward.
of input tensors in forward.
"""
"""
value
,
value_spatial_shapes
,
value_level_start_index
,
\
value
,
value_spatial_shapes
,
value_level_start_index
,
\
sampling_locations
,
attention_weights
=
ctx
.
saved_tensors
sampling_locations
,
attention_weights
=
ctx
.
saved_tensors
grad_value
=
torch
.
zeros_like
(
value
)
grad_value
=
torch
.
zeros_like
(
value
)
grad_sampling_loc
=
torch
.
zeros_like
(
sampling_locations
)
grad_sampling_loc
=
torch
.
zeros_like
(
sampling_locations
)
grad_attn_weight
=
torch
.
zeros_like
(
attention_weights
)
grad_attn_weight
=
torch
.
zeros_like
(
attention_weights
)
ext_module
.
ms_deform_attn_backward
(
ext_module
.
ms_deform_attn_backward
(
value
,
value
,
value_spatial_shapes
,
value_spatial_shapes
,
value_level_start_index
,
value_level_start_index
,
sampling_locations
,
sampling_locations
,
attention_weights
,
attention_weights
,
grad_output
.
contiguous
(),
grad_output
.
contiguous
(),
grad_value
,
grad_value
,
grad_sampling_loc
,
grad_sampling_loc
,
grad_attn_weight
,
grad_attn_weight
,
im2col_step
=
ctx
.
im2col_step
)
im2col_step
=
ctx
.
im2col_step
)
return
grad_value
,
None
,
None
,
\
return
grad_value
,
None
,
None
,
\
grad_sampling_loc
,
grad_attn_weight
,
None
grad_sampling_loc
,
grad_attn_weight
,
None
def
multi_scale_deformable_attn_pytorch
(
value
,
value_spatial_shapes
,
def
multi_scale_deformable_attn_pytorch
(
value
,
value_spatial_shapes
,
sampling_locations
,
attention_weights
):
sampling_locations
,
attention_weights
):
"""CPU version of multi-scale deformable attention.
"""CPU version of multi-scale deformable attention.
Args:
Args:
value (Tensor): The value has shape
value (Tensor): The value has shape
(bs, num_keys, mum_heads, embed_dims//num_heads)
(bs, num_keys, mum_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
sampling_locations (Tensor): The location of sampling points,
has shape
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
attention_weights (Tensor): The weight of sampling points used
when calculate the attention, has shape
when calculate the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
(bs ,num_queries, num_heads, num_levels, num_points),
Returns:
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
Tensor: has shape (bs, num_queries, embed_dims)
"""
"""
bs
,
_
,
num_heads
,
embed_dims
=
value
.
shape
bs
,
_
,
num_heads
,
embed_dims
=
value
.
shape
_
,
num_queries
,
num_heads
,
num_levels
,
num_points
,
_
=
\
_
,
num_queries
,
num_heads
,
num_levels
,
num_points
,
_
=
\
sampling_locations
.
shape
sampling_locations
.
shape
value_list
=
value
.
split
([
H_
*
W_
for
H_
,
W_
in
value_spatial_shapes
],
value_list
=
value
.
split
([
H_
*
W_
for
H_
,
W_
in
value_spatial_shapes
],
dim
=
1
)
dim
=
1
)
sampling_grids
=
2
*
sampling_locations
-
1
sampling_grids
=
2
*
sampling_locations
-
1
sampling_value_list
=
[]
sampling_value_list
=
[]
for
level
,
(
H_
,
W_
)
in
enumerate
(
value_spatial_shapes
):
for
level
,
(
H_
,
W_
)
in
enumerate
(
value_spatial_shapes
):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
# bs*num_heads, embed_dims, H_, W_
value_l_
=
value_list
[
level
].
flatten
(
2
).
transpose
(
1
,
2
).
reshape
(
value_l_
=
value_list
[
level
].
flatten
(
2
).
transpose
(
1
,
2
).
reshape
(
bs
*
num_heads
,
embed_dims
,
H_
,
W_
)
bs
*
num_heads
,
embed_dims
,
H_
,
W_
)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_
=
sampling_grids
[:,
:,
:,
sampling_grid_l_
=
sampling_grids
[:,
:,
:,
level
].
transpose
(
1
,
2
).
flatten
(
0
,
1
)
level
].
transpose
(
1
,
2
).
flatten
(
0
,
1
)
# bs*num_heads, embed_dims, num_queries, num_points
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_
=
F
.
grid_sample
(
sampling_value_l_
=
F
.
grid_sample
(
value_l_
,
value_l_
,
sampling_grid_l_
,
sampling_grid_l_
,
mode
=
'bilinear'
,
mode
=
'bilinear'
,
padding_mode
=
'zeros'
,
padding_mode
=
'zeros'
,
align_corners
=
False
)
align_corners
=
False
)
sampling_value_list
.
append
(
sampling_value_l_
)
sampling_value_list
.
append
(
sampling_value_l_
)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights
=
attention_weights
.
transpose
(
1
,
2
).
reshape
(
attention_weights
=
attention_weights
.
transpose
(
1
,
2
).
reshape
(
bs
*
num_heads
,
1
,
num_queries
,
num_levels
*
num_points
)
bs
*
num_heads
,
1
,
num_queries
,
num_levels
*
num_points
)
output
=
(
torch
.
stack
(
sampling_value_list
,
dim
=-
2
).
flatten
(
-
2
)
*
output
=
(
torch
.
stack
(
sampling_value_list
,
dim
=-
2
).
flatten
(
-
2
)
*
attention_weights
).
sum
(
-
1
).
view
(
bs
,
num_heads
*
embed_dims
,
attention_weights
).
sum
(
-
1
).
view
(
bs
,
num_heads
*
embed_dims
,
num_queries
)
num_queries
)
return
output
.
transpose
(
1
,
2
).
contiguous
()
return
output
.
transpose
(
1
,
2
).
contiguous
()
@
ATTENTION
.
register_module
()
@
ATTENTION
.
register_module
()
class
MultiScaleDeformableAttentionFP32
(
BaseModule
):
class
MultiScaleDeformableAttentionFP32
(
BaseModule
):
"""An attention module used in Deformable-Detr. `Deformable DETR:
"""An attention module used in Deformable-Detr. `Deformable DETR:
Deformable Transformers for End-to-End Object Detection.
Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
Args:
embed_dims (int): The embedding dimension of Attention.
embed_dims (int): The embedding dimension of Attention.
Default: 256.
Default: 256.
num_heads (int): Parallel attention heads. Default: 64.
num_heads (int): Parallel attention heads. Default: 64.
num_levels (int): The number of feature map used in
num_levels (int): The number of feature map used in
Attention. Default: 4.
Attention. Default: 4.
num_points (int): The number of sampling points for
num_points (int): The number of sampling points for
each query in each head. Default: 4.
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
im2col_step (int): The step used in image_to_column.
Default: 64.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
or (n, batch, embed_dim). Default to False.
norm_cfg (dict): Config dict for normalization layer.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
Default: None.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
embed_dims
=
256
,
embed_dims
=
256
,
num_heads
=
8
,
num_heads
=
8
,
num_levels
=
4
,
num_levels
=
4
,
num_points
=
4
,
num_points
=
4
,
im2col_step
=
64
,
im2col_step
=
64
,
dropout
=
0.1
,
dropout
=
0.1
,
batch_first
=
False
,
batch_first
=
False
,
norm_cfg
=
None
,
norm_cfg
=
None
,
init_cfg
=
None
):
init_cfg
=
None
):
super
().
__init__
(
init_cfg
)
super
().
__init__
(
init_cfg
)
if
embed_dims
%
num_heads
!=
0
:
if
embed_dims
%
num_heads
!=
0
:
raise
ValueError
(
f
'embed_dims must be divisible by num_heads, '
raise
ValueError
(
f
'embed_dims must be divisible by num_heads, '
f
'but got
{
embed_dims
}
and
{
num_heads
}
'
)
f
'but got
{
embed_dims
}
and
{
num_heads
}
'
)
dim_per_head
=
embed_dims
//
num_heads
dim_per_head
=
embed_dims
//
num_heads
self
.
norm_cfg
=
norm_cfg
self
.
norm_cfg
=
norm_cfg
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
batch_first
=
batch_first
self
.
batch_first
=
batch_first
# you'd better set dim_per_head to a power of 2
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
# which is more efficient in the CUDA implementation
def
_is_power_of_2
(
n
):
def
_is_power_of_2
(
n
):
if
(
not
isinstance
(
n
,
int
))
or
(
n
<
0
):
if
(
not
isinstance
(
n
,
int
))
or
(
n
<
0
):
raise
ValueError
(
raise
ValueError
(
'invalid input for _is_power_of_2: {} (type: {})'
.
format
(
'invalid input for _is_power_of_2: {} (type: {})'
.
format
(
n
,
type
(
n
)))
n
,
type
(
n
)))
return
(
n
&
(
n
-
1
)
==
0
)
and
n
!=
0
return
(
n
&
(
n
-
1
)
==
0
)
and
n
!=
0
if
not
_is_power_of_2
(
dim_per_head
):
if
not
_is_power_of_2
(
dim_per_head
):
warnings
.
warn
(
warnings
.
warn
(
"You'd better set embed_dims in "
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.'
)
'which is more efficient in our CUDA implementation.'
)
self
.
im2col_step
=
im2col_step
self
.
im2col_step
=
im2col_step
self
.
embed_dims
=
embed_dims
self
.
embed_dims
=
embed_dims
self
.
num_levels
=
num_levels
self
.
num_levels
=
num_levels
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
num_points
=
num_points
self
.
num_points
=
num_points
self
.
sampling_offsets
=
nn
.
Linear
(
self
.
sampling_offsets
=
nn
.
Linear
(
embed_dims
,
num_heads
*
num_levels
*
num_points
*
2
)
embed_dims
,
num_heads
*
num_levels
*
num_points
*
2
)
self
.
attention_weights
=
nn
.
Linear
(
embed_dims
,
self
.
attention_weights
=
nn
.
Linear
(
embed_dims
,
num_heads
*
num_levels
*
num_points
)
num_heads
*
num_levels
*
num_points
)
self
.
value_proj
=
nn
.
Linear
(
embed_dims
,
embed_dims
)
self
.
value_proj
=
nn
.
Linear
(
embed_dims
,
embed_dims
)
self
.
output_proj
=
nn
.
Linear
(
embed_dims
,
embed_dims
)
self
.
output_proj
=
nn
.
Linear
(
embed_dims
,
embed_dims
)
self
.
init_weights
()
self
.
init_weights
()
def
init_weights
(
self
):
def
init_weights
(
self
):
"""Default initialization for Parameters of Module."""
"""Default initialization for Parameters of Module."""
constant_init
(
self
.
sampling_offsets
,
0.
)
constant_init
(
self
.
sampling_offsets
,
0.
)
thetas
=
torch
.
arange
(
thetas
=
torch
.
arange
(
self
.
num_heads
,
self
.
num_heads
,
dtype
=
torch
.
float32
)
*
(
2.0
*
math
.
pi
/
self
.
num_heads
)
dtype
=
torch
.
float32
)
*
(
2.0
*
math
.
pi
/
self
.
num_heads
)
grid_init
=
torch
.
stack
([
thetas
.
cos
(),
thetas
.
sin
()],
-
1
)
grid_init
=
torch
.
stack
([
thetas
.
cos
(),
thetas
.
sin
()],
-
1
)
grid_init
=
(
grid_init
/
grid_init
=
(
grid_init
/
grid_init
.
abs
().
max
(
-
1
,
keepdim
=
True
)[
0
]).
view
(
grid_init
.
abs
().
max
(
-
1
,
keepdim
=
True
)[
0
]).
view
(
self
.
num_heads
,
1
,
1
,
self
.
num_heads
,
1
,
1
,
2
).
repeat
(
1
,
self
.
num_levels
,
self
.
num_points
,
1
)
2
).
repeat
(
1
,
self
.
num_levels
,
self
.
num_points
,
1
)
for
i
in
range
(
self
.
num_points
):
for
i
in
range
(
self
.
num_points
):
grid_init
[:,
:,
i
,
:]
*=
i
+
1
grid_init
[:,
:,
i
,
:]
*=
i
+
1
self
.
sampling_offsets
.
bias
.
data
=
grid_init
.
view
(
-
1
)
self
.
sampling_offsets
.
bias
.
data
=
grid_init
.
view
(
-
1
)
constant_init
(
self
.
attention_weights
,
val
=
0.
,
bias
=
0.
)
constant_init
(
self
.
attention_weights
,
val
=
0.
,
bias
=
0.
)
xavier_init
(
self
.
value_proj
,
distribution
=
'uniform'
,
bias
=
0.
)
xavier_init
(
self
.
value_proj
,
distribution
=
'uniform'
,
bias
=
0.
)
xavier_init
(
self
.
output_proj
,
distribution
=
'uniform'
,
bias
=
0.
)
xavier_init
(
self
.
output_proj
,
distribution
=
'uniform'
,
bias
=
0.
)
self
.
_is_init
=
True
self
.
_is_init
=
True
@
deprecated_api_warning
({
'residual'
:
'identity'
},
@
deprecated_api_warning
({
'residual'
:
'identity'
},
cls_name
=
'MultiScaleDeformableAttention'
)
cls_name
=
'MultiScaleDeformableAttention'
)
def
forward
(
self
,
def
forward
(
self
,
query
,
query
,
key
=
None
,
key
=
None
,
value
=
None
,
value
=
None
,
identity
=
None
,
identity
=
None
,
query_pos
=
None
,
query_pos
=
None
,
key_padding_mask
=
None
,
key_padding_mask
=
None
,
reference_points
=
None
,
reference_points
=
None
,
spatial_shapes
=
None
,
spatial_shapes
=
None
,
level_start_index
=
None
,
level_start_index
=
None
,
**
kwargs
):
**
kwargs
):
"""Forward Function of MultiScaleDeformAttention.
"""Forward Function of MultiScaleDeformAttention.
Args:
Args:
query (Tensor): Query of Transformer with shape
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
`(num_key, bs, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
same shape as `query`. Default None. If None,
`query` will be used.
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
key_pos (Tensor): The positional encoding for `key`. Default
None.
None.
reference_points (Tensor): The normalized reference
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
additional two dimensions is (w, h) to
form reference boxes.
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
"""
if
value
is
None
:
if
value
is
None
:
value
=
query
value
=
query
if
identity
is
None
:
if
identity
is
None
:
identity
=
query
identity
=
query
if
query_pos
is
not
None
:
if
query_pos
is
not
None
:
query
=
query
+
query_pos
query
=
query
+
query_pos
if
not
self
.
batch_first
:
if
not
self
.
batch_first
:
# change to (bs, num_query ,embed_dims)
# change to (bs, num_query ,embed_dims)
query
=
query
.
permute
(
1
,
0
,
2
)
query
=
query
.
permute
(
1
,
0
,
2
)
value
=
value
.
permute
(
1
,
0
,
2
)
value
=
value
.
permute
(
1
,
0
,
2
)
bs
,
num_query
,
_
=
query
.
shape
bs
,
num_query
,
_
=
query
.
shape
bs
,
num_value
,
_
=
value
.
shape
bs
,
num_value
,
_
=
value
.
shape
assert
(
spatial_shapes
[:,
0
]
*
spatial_shapes
[:,
1
]).
sum
()
==
num_value
assert
(
spatial_shapes
[:,
0
]
*
spatial_shapes
[:,
1
]).
sum
()
==
num_value
value
=
self
.
value_proj
(
value
)
value
=
self
.
value_proj
(
value
)
if
key_padding_mask
is
not
None
:
if
key_padding_mask
is
not
None
:
value
=
value
.
masked_fill
(
key_padding_mask
[...,
None
],
0.0
)
value
=
value
.
masked_fill
(
key_padding_mask
[...,
None
],
0.0
)
value
=
value
.
view
(
bs
,
num_value
,
self
.
num_heads
,
-
1
)
value
=
value
.
view
(
bs
,
num_value
,
self
.
num_heads
,
-
1
)
sampling_offsets
=
self
.
sampling_offsets
(
query
).
view
(
sampling_offsets
=
self
.
sampling_offsets
(
query
).
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
,
2
)
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
,
2
)
attention_weights
=
self
.
attention_weights
(
query
).
view
(
attention_weights
=
self
.
attention_weights
(
query
).
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
*
self
.
num_points
)
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
*
self
.
num_points
)
attention_weights
=
attention_weights
.
softmax
(
-
1
)
attention_weights
=
attention_weights
.
softmax
(
-
1
)
attention_weights
=
attention_weights
.
view
(
bs
,
num_query
,
attention_weights
=
attention_weights
.
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_levels
,
self
.
num_points
)
self
.
num_points
)
if
reference_points
.
shape
[
-
1
]
==
2
:
if
reference_points
.
shape
[
-
1
]
==
2
:
offset_normalizer
=
torch
.
stack
(
offset_normalizer
=
torch
.
stack
(
[
spatial_shapes
[...,
1
],
spatial_shapes
[...,
0
]],
-
1
)
[
spatial_shapes
[...,
1
],
spatial_shapes
[...,
0
]],
-
1
)
sampling_locations
=
reference_points
[:,
:,
None
,
:,
None
,
:]
\
sampling_locations
=
reference_points
[:,
:,
None
,
:,
None
,
:]
\
+
sampling_offsets
\
+
sampling_offsets
\
/
offset_normalizer
[
None
,
None
,
None
,
:,
None
,
:]
/
offset_normalizer
[
None
,
None
,
None
,
:,
None
,
:]
elif
reference_points
.
shape
[
-
1
]
==
4
:
elif
reference_points
.
shape
[
-
1
]
==
4
:
sampling_locations
=
reference_points
[:,
:,
None
,
:,
None
,
:
2
]
\
sampling_locations
=
reference_points
[:,
:,
None
,
:,
None
,
:
2
]
\
+
sampling_offsets
/
self
.
num_points
\
+
sampling_offsets
/
self
.
num_points
\
*
reference_points
[:,
:,
None
,
:,
None
,
2
:]
\
*
reference_points
[:,
:,
None
,
:,
None
,
2
:]
\
*
0.5
*
0.5
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
'Last dim of reference_points must be'
f
'Last dim of reference_points must be'
f
' 2 or 4, but get
{
reference_points
.
shape
[
-
1
]
}
instead.'
)
f
' 2 or 4, but get
{
reference_points
.
shape
[
-
1
]
}
instead.'
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
output
=
MultiScaleDeformableAttnFunctionFp32
.
apply
(
output
=
MultiScaleDeformableAttnFunctionFp32
.
apply
(
value
,
spatial_shapes
,
level_start_index
,
sampling_locations
,
value
,
spatial_shapes
,
level_start_index
,
sampling_locations
,
attention_weights
,
self
.
im2col_step
)
attention_weights
,
self
.
im2col_step
)
else
:
else
:
output
=
multi_scale_deformable_attn_pytorch
(
output
=
multi_scale_deformable_attn_pytorch
(
value
,
spatial_shapes
,
level_start_index
,
sampling_locations
,
value
,
spatial_shapes
,
level_start_index
,
sampling_locations
,
attention_weights
,
self
.
im2col_step
)
attention_weights
,
self
.
im2col_step
)
output
=
self
.
output_proj
(
output
)
output
=
self
.
output_proj
(
output
)
if
not
self
.
batch_first
:
if
not
self
.
batch_first
:
# (num_query, bs ,embed_dims)
# (num_query, bs ,embed_dims)
output
=
output
.
permute
(
1
,
0
,
2
)
output
=
output
.
permute
(
1
,
0
,
2
)
return
self
.
dropout
(
output
)
+
identity
return
self
.
dropout
(
output
)
+
identity
\ No newline at end of file
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/tools/dist_test.sh
→
autonomous_driving/Online-HD-Map-Construction/tools/dist_test.sh
View file @
f3b13cad
#!/usr/bin/env bash
#!/usr/bin/env bash
CONFIG
=
$1
CONFIG
=
$1
CHECKPOINT
=
$2
CHECKPOINT
=
$2
GPUS
=
$3
GPUS
=
$3
PORT
=
${
PORT
:-
29500
}
PORT
=
${
PORT
:-
29500
}
PYTHONPATH
=
"
$(
dirname
$0
)
/.."
:
$PYTHONPATH
\
PYTHONPATH
=
"
$(
dirname
$0
)
/.."
:
$PYTHONPATH
\
python
-m
torch.distributed.launch
--nproc_per_node
=
$GPUS
--master_port
=
$PORT
\
python
-m
torch.distributed.launch
--nproc_per_node
=
$GPUS
--master_port
=
$PORT
\
$(
dirname
"
$0
"
)
/test.py
$CONFIG
$CHECKPOINT
--launcher
pytorch
${
@
:4
}
$(
dirname
"
$0
"
)
/test.py
$CONFIG
$CHECKPOINT
--launcher
pytorch
${
@
:4
}
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/tools/dist_train.sh
→
autonomous_driving/Online-HD-Map-Construction/tools/dist_train.sh
View file @
f3b13cad
#!/usr/bin/env bash
# Distributed training launcher.
# Usage: ./dist_train.sh <config> <num_gpus> [extra train.py args...]
# PORT may be overridden via the environment (default 29500).

CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}

# Quote "$0", the positional parameters, and "${@:3}" so paths and
# arguments containing spaces survive word splitting (the original
# left most expansions unquoted).
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \
    "$(dirname "$0")"/train.py "$CONFIG" --launcher pytorch "${@:3}"
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/tools/evaluate_submission.py
→
autonomous_driving/Online-HD-Map-Construction/tools/evaluate_submission.py
View file @
f3b13cad
import
sys
import
sys
import
os
import
os
sys
.
path
.
append
(
os
.
path
.
abspath
(
'.'
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
'.'
))
from
src.datasets.evaluation.vector_eval
import
VectorEvaluate
from
src.datasets.evaluation.vector_eval
import
VectorEvaluate
import
argparse
import
argparse
def parse_args():
    """Parse command-line arguments for the submission evaluator.

    Returns:
        argparse.Namespace: parsed arguments with ``submission`` (path to
        the result file) and ``gt`` (path to the annotation file).
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a submission file')
    parser.add_argument(
        'submission',
        help='submission file in pickle or json format to be evaluated')
    parser.add_argument(
        'gt',
        help='gt annotation file')
    # Return the namespace directly; no need for an intermediate variable.
    return parser.parse_args()
def main(args):
    """Evaluate a submission file against ground-truth annotations.

    Args:
        args: namespace with ``gt`` (annotation file path) and
            ``submission`` (result file path) attributes.
    """
    # n_workers=0 keeps the evaluation single-process.
    scores = VectorEvaluate(args.gt, n_workers=0).evaluate(args.submission)
    print(scores)
if __name__ == '__main__':
    # Entry point: parse CLI arguments, then run the evaluation.
    main(parse_args())
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/tools/mmdet_test.py
→
autonomous_driving/Online-HD-Map-Construction/tools/mmdet_test.py
View file @
f3b13cad
import
os.path
as
osp
import
os.path
as
osp
import
pickle
import
pickle
import
shutil
import
shutil
import
tempfile
import
tempfile
import
time
import
time
import
mmcv
import
mmcv
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
mmcv.image
import
tensor2imgs
from
mmcv.image
import
tensor2imgs
from
mmcv.runner
import
get_dist_info
from
mmcv.runner
import
get_dist_info
from
mmdet.core
import
encode_mask_results
from
mmdet.core
import
encode_mask_results
def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    show_score_thr=0.3):
    """Run inference on a single GPU, optionally drawing/saving results.

    Args:
        model (nn.Module): wrapped model under test (visualization calls
            ``model.module.show_result``).
        data_loader: dataloader yielding test batches.
        show (bool): display drawn results on screen.
        out_dir (str | None): directory to save drawn results into.
        show_score_thr (float): score threshold used for visualization.

    Returns:
        list: collected per-sample results (mask results RLE-encoded).
    """
    model.eval()
    dataset = data_loader.dataset
    progress = mmcv.ProgressBar(len(dataset))
    collected = []
    for _, batch in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **batch)
        batch_size = len(result)

        if show or out_dir:
            # Recover the raw image tensor: with batch size 1 the value may
            # already be a tensor, otherwise unwrap it via .data[0]
            # (presumably a DataContainer -- confirm against the pipeline).
            if batch_size == 1 and isinstance(batch['img'][0], torch.Tensor):
                img_tensor = batch['img'][0]
            else:
                img_tensor = batch['img'][0].data[0]
            img_metas = batch['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)

            for idx, (img, meta) in enumerate(zip(imgs, img_metas)):
                h, w, _ = meta['img_shape']
                # Crop padding away, then resize back to the original shape.
                cropped = img[:h, :w, :]
                ori_h, ori_w = meta['ori_shape'][:-1]
                shown = mmcv.imresize(cropped, (ori_w, ori_h))
                out_file = (osp.join(out_dir, meta['ori_filename'])
                            if out_dir else None)
                model.module.show_result(
                    shown,
                    result[idx],
                    show=show,
                    out_file=out_file,
                    score_thr=show_score_thr)

        # RLE-encode instance masks so they are cheap to accumulate.
        if isinstance(result[0], tuple):
            result = [(bbox_res, encode_mask_results(mask_res))
                      for bbox_res, mask_res in result]
        collected.extend(result)

        for _ in range(batch_size):
            progress.update()
    return collected
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test model with multiple gpus.

    Runs distributed inference and gathers results under two modes:
    with ``gpu_collect=True`` results are encoded to gpu tensors and
    exchanged via gpu communication; otherwise each rank saves its part
    to ``tmpdir`` and the rank 0 worker collects them.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect
            results.

    Returns:
        list: The prediction results.
    """
    model.eval()
    rank, world_size = get_dist_info()
    dataset = data_loader.dataset
    if rank == 0:
        progress = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.

    gathered = []
    for _, batch in enumerate(data_loader):
        with torch.no_grad():
            output = model(return_loss=False, rescale=True, **batch)
            gathered.extend(output)

        if rank == 0:
            # Each of the world_size ranks processed len(output) samples.
            for _ in range(len(output) * world_size):
                progress.update()

    # Collect results from all ranks.
    if gpu_collect:
        return collect_results_gpu(gathered, len(dataset))
    return collect_results_cpu(gathered, len(dataset), tmpdir)
def collect_results_cpu(result_part, size, tmpdir=None):
    """Gather per-rank result lists through a shared filesystem.

    Each rank pickles its partial results into ``tmpdir``; rank 0 loads
    every part, interleaves them back into dataset order, trims padding,
    and returns the merged list. All other ranks return None.

    Args:
        result_part (list): this rank's share of the results.
        size (int): total dataset length (used to trim padded samples).
        tmpdir (str | None): directory to exchange parts through; when
            None, rank 0 creates one and broadcasts its name.
    """
    rank, world_size = get_dist_info()
    if tmpdir is None:
        # Rank 0 creates a temp dir and broadcasts its name to all ranks
        # as a fixed-size uint8 tensor padded with ASCII spaces (code 32).
        MAX_LEN = 512
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            encoded = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(encoded)] = encoded
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)

    # Dump this rank's part, then wait until every rank has written.
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()

    # Only rank 0 merges; the rest are done.
    if rank != 0:
        return None

    parts = [
        mmcv.load(osp.join(tmpdir, f'part_{i}.pkl'))
        for i in range(world_size)
    ]
    # Interleave the parts to restore dataset order.
    merged = []
    for grouped in zip(*parts):
        merged.extend(list(grouped))
    # The dataloader may pad some samples; trim back to the dataset size.
    merged = merged[:size]
    shutil.rmtree(tmpdir)
    return merged
def collect_results_gpu(result_part, size):
    """Gather per-rank result lists over GPU communication, no disk I/O.

    Each rank pickles its partial results into a uint8 CUDA tensor; the
    tensors are padded to the longest payload and all-gathered. Rank 0
    unpickles every part, interleaves them back into dataset order, and
    returns the merged list trimmed to ``size``. Other ranks return None.

    Args:
        result_part (list): this rank's share of the results.
        size (int): total dataset length (used to trim padded samples).
    """
    rank, world_size = get_dist_info()
    # Serialize this rank's results into a byte tensor on the GPU.
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # Exchange payload lengths so every rank can pad to the maximum.
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    shape_max = torch.tensor(shape_list).max()
    padded = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    padded[:shape_tensor[0]] = part_tensor
    recv_buffers = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # Gather every rank's padded payload.
    dist.all_gather(recv_buffers, padded)

    if rank == 0:
        # Strip padding, unpickle each part, and interleave to dataset order.
        parts = [
            pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
            for recv, shape in zip(recv_buffers, shape_list)
        ]
        merged = []
        for grouped in zip(*parts):
            merged.extend(list(grouped))
        # The dataloader may pad some samples; trim back to dataset size.
        return merged[:size]
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/tools/mmdet_train.py
→
autonomous_driving/Online-HD-Map-Construction/tools/mmdet_train.py
View file @
f3b13cad
import
random
import
random
import
warnings
import
warnings
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.runner
import
(
HOOKS
,
DistSamplerSeedHook
,
EpochBasedRunner
,
from
mmcv.runner
import
(
HOOKS
,
DistSamplerSeedHook
,
EpochBasedRunner
,
Fp16OptimizerHook
,
OptimizerHook
,
build_optimizer
,
Fp16OptimizerHook
,
OptimizerHook
,
build_optimizer
,
build_runner
)
build_runner
)
from
mmcv.utils
import
build_from_cfg
from
mmcv.utils
import
build_from_cfg
from
mmdet.core
import
DistEvalHook
,
EvalHook
from
mmdet.core
import
DistEvalHook
,
EvalHook
from
mmdet.datasets
import
(
build_dataloader
,
build_dataset
,
from
mmdet.datasets
import
(
build_dataloader
,
build_dataset
,
replace_ImageToTensor
)
replace_ImageToTensor
)
from
mmdet.utils
import
get_root_logger
from
mmdet.utils
import
get_root_logger
def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    # Seed every RNG the training stack draws from.
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    if deterministic:
        # Trade cuDNN autotuning speed for reproducible kernel selection.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Build dataloaders and a runner, register hooks, and start training.

    Args:
        model (nn.Module): detector to train.
        dataset: a dataset or list of datasets (one per workflow stage).
        cfg: full training config (data, optimizer, runner, hooks, ...).
        distributed (bool): wrap the model in MMDistributedDataParallel.
        validate (bool): register an evaluation hook over ``cfg.data.val``.
        timestamp (str | None): stamp reused for .log/.log.json filenames.
        meta (dict | None): extra info passed through to the runner.
    """
    logger = get_root_logger(cfg.log_level)

    # Prepare one dataloader per dataset entry.
    datasets = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiments')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiments')
        # The legacy key takes precedence over samples_per_gpu.
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in datasets
    ]

    # Move the model onto the GPU(s).
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Forwarded to torch.nn.parallel.DistributedDataParallel.
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # Build the optimizer and the runner.
    optimizer = build_optimizer(model, cfg.optimizer)
    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs
    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16: wrap the optimizer hook when mixed precision is configured.
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # Register the standard training hooks.
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # Register an eval hook when validation is requested.
    if validate:
        # Support batch_size > 1 in validation.
        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
        if val_samples_per_gpu > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'.
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # Register any user-defined hooks from the config.
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    # Resume or warm-start from a checkpoint, then run the workflow.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/tools/test.py
→
autonomous_driving/Online-HD-Map-Construction/tools/test.py
View file @
f3b13cad
import
argparse
import
argparse
import
mmcv
import
mmcv
import
os
import
os
import
os.path
as
osp
import
os.path
as
osp
import
torch
import
torch
import
warnings
import
warnings
from
mmcv
import
Config
,
DictAction
from
mmcv
import
Config
,
DictAction
from
mmcv.cnn
import
fuse_conv_bn
from
mmcv.cnn
import
fuse_conv_bn
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.runner
import
(
get_dist_info
,
init_dist
,
load_checkpoint
,
from
mmcv.runner
import
(
get_dist_info
,
init_dist
,
load_checkpoint
,
wrap_fp16_model
)
wrap_fp16_model
)
from
mmdet3d.apis
import
single_gpu_test
from
mmdet3d.apis
import
single_gpu_test
from
mmdet3d.datasets
import
build_dataloader
,
build_dataset
from
mmdet3d.datasets
import
build_dataloader
,
build_dataset
from
mmdet3d.models
import
build_model
from
mmdet3d.models
import
build_model
from
mmdet_test
import
multi_gpu_test
from
mmdet_test
import
multi_gpu_test
from
mmdet_train
import
set_random_seed
from
mmdet_train
import
set_random_seed
from
mmdet.datasets
import
replace_ImageToTensor
from
mmdet.datasets
import
replace_ImageToTensor
def parse_args():
    """Parse command-line arguments for model testing/evaluation.

    Also propagates ``--local_rank`` into the ``LOCAL_RANK`` environment
    variable when a launcher has not already set it.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument(
        '--split', type=str, required=True, help='which split to test on')
    parser.add_argument(
        '--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        # Fix: the adjacent literals previously joined without a space
        # ("increasethe inference speed").
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--format-only',
        action='store_true',
        # Fix: missing space between "It is" and "useful".
        help='Format the output results without perform evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        action='store_true',
        help='whether to run evaluation.')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def
main
():
def
main
():
args
=
parse_args
()
args
=
parse_args
()
if
args
.
split
not
in
[
'val'
,
'test'
]:
if
args
.
split
not
in
[
'val'
,
'test'
]:
raise
ValueError
(
'Please choose "val" or "test" split for testing'
)
raise
ValueError
(
'Please choose "val" or "test" split for testing'
)
if
(
args
.
eval
and
args
.
format_only
)
or
(
not
args
.
eval
and
not
args
.
format_only
):
if
(
args
.
eval
and
args
.
format_only
)
or
(
not
args
.
eval
and
not
args
.
format_only
):
raise
ValueError
(
'Please specify exactly one operation (eval/format) '
raise
ValueError
(
'Please specify exactly one operation (eval/format) '
'with the argument "--eval" or "--format-only"'
)
'with the argument "--eval" or "--format-only"'
)
if
args
.
eval
and
args
.
split
==
'test'
:
if
args
.
eval
and
args
.
split
==
'test'
:
raise
ValueError
(
'Cannot evaluate on test set'
)
raise
ValueError
(
'Cannot evaluate on test set'
)
cfg
=
Config
.
fromfile
(
args
.
config
)
cfg
=
Config
.
fromfile
(
args
.
config
)
# import modules from string list.
# import modules from string list.
if
cfg
.
get
(
'custom_imports'
,
None
):
if
cfg
.
get
(
'custom_imports'
,
None
):
from
mmcv.utils
import
import_modules_from_strings
from
mmcv.utils
import
import_modules_from_strings
import_modules_from_strings
(
**
cfg
[
'custom_imports'
])
import_modules_from_strings
(
**
cfg
[
'custom_imports'
])
# set cudnn_benchmark
# set cudnn_benchmark
if
cfg
.
get
(
'cudnn_benchmark'
,
False
):
if
cfg
.
get
(
'cudnn_benchmark'
,
False
):
torch
.
backends
.
cudnn
.
benchmark
=
True
torch
.
backends
.
cudnn
.
benchmark
=
True
# import modules from plguin/xx, registry will be updated
# import modules from plguin/xx, registry will be updated
import
sys
import
sys
sys
.
path
.
append
(
os
.
path
.
abspath
(
'.'
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
'.'
))
if
hasattr
(
cfg
,
'plugin'
):
if
hasattr
(
cfg
,
'plugin'
):
if
cfg
.
plugin
:
if
cfg
.
plugin
:
import
importlib
import
importlib
if
hasattr
(
cfg
,
'plugin_dir'
):
if
hasattr
(
cfg
,
'plugin_dir'
):
def
import_path
(
plugin_dir
):
def
import_path
(
plugin_dir
):
_module_dir
=
os
.
path
.
dirname
(
plugin_dir
)
_module_dir
=
os
.
path
.
dirname
(
plugin_dir
)
_module_dir
=
_module_dir
.
split
(
'/'
)
_module_dir
=
_module_dir
.
split
(
'/'
)
_module_path
=
_module_dir
[
0
]
_module_path
=
_module_dir
[
0
]
for
m
in
_module_dir
[
1
:]:
for
m
in
_module_dir
[
1
:]:
_module_path
=
_module_path
+
'.'
+
m
_module_path
=
_module_path
+
'.'
+
m
print
(
f
'importing
{
_module_path
}
/'
)
print
(
f
'importing
{
_module_path
}
/'
)
plg_lib
=
importlib
.
import_module
(
_module_path
)
plg_lib
=
importlib
.
import_module
(
_module_path
)
plugin_dirs
=
cfg
.
plugin_dir
plugin_dirs
=
cfg
.
plugin_dir
if
not
isinstance
(
plugin_dirs
,
list
):
if
not
isinstance
(
plugin_dirs
,
list
):
plugin_dirs
=
[
plugin_dirs
,]
plugin_dirs
=
[
plugin_dirs
,]
for
plugin_dir
in
plugin_dirs
:
for
plugin_dir
in
plugin_dirs
:
import_path
(
plugin_dir
)
import_path
(
plugin_dir
)
else
:
else
:
# import dir is the dirpath for the config file
# import dir is the dirpath for the config file
_module_dir
=
os
.
path
.
dirname
(
args
.
config
)
_module_dir
=
os
.
path
.
dirname
(
args
.
config
)
_module_dir
=
_module_dir
.
split
(
'/'
)
_module_dir
=
_module_dir
.
split
(
'/'
)
_module_path
=
_module_dir
[
0
]
_module_path
=
_module_dir
[
0
]
for
m
in
_module_dir
[
1
:]:
for
m
in
_module_dir
[
1
:]:
_module_path
=
_module_path
+
'.'
+
m
_module_path
=
_module_path
+
'.'
+
m
print
(
f
'importing
{
_module_path
}
/'
)
print
(
f
'importing
{
_module_path
}
/'
)
plg_lib
=
importlib
.
import_module
(
_module_path
)
plg_lib
=
importlib
.
import_module
(
_module_path
)
cfg_data_dict
=
cfg
.
data
.
get
(
args
.
split
)
cfg_data_dict
=
cfg
.
data
.
get
(
args
.
split
)
cfg
.
model
.
pretrained
=
None
cfg
.
model
.
pretrained
=
None
# in case the test dataset is concatenated
# in case the test dataset is concatenated
samples_per_gpu
=
1
samples_per_gpu
=
1
cfg_data_dict
.
test_mode
=
True
cfg_data_dict
.
test_mode
=
True
samples_per_gpu
=
cfg_data_dict
.
pop
(
'samples_per_gpu'
,
1
)
samples_per_gpu
=
cfg_data_dict
.
pop
(
'samples_per_gpu'
,
1
)
if
samples_per_gpu
>
1
:
if
samples_per_gpu
>
1
:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg_data_dict
.
pipeline
=
replace_ImageToTensor
(
cfg_data_dict
.
pipeline
=
replace_ImageToTensor
(
cfg_data_dict
.
pipeline
)
cfg_data_dict
.
pipeline
)
# init distributed env first, since logger depends on the dist info.
# init distributed env first, since logger depends on the dist info.
if
args
.
launcher
==
'none'
:
if
args
.
launcher
==
'none'
:
distributed
=
False
distributed
=
False
else
:
else
:
distributed
=
True
distributed
=
True
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
# set random seeds
# set random seeds
if
args
.
seed
is
not
None
:
if
args
.
seed
is
not
None
:
set_random_seed
(
args
.
seed
,
deterministic
=
args
.
deterministic
)
set_random_seed
(
args
.
seed
,
deterministic
=
args
.
deterministic
)
# build the dataloader
# build the dataloader
if
args
.
work_dir
is
not
None
:
if
args
.
work_dir
is
not
None
:
# update configs according to CLI args if args.work_dir is not None
# update configs according to CLI args if args.work_dir is not None
cfg
.
work_dir
=
args
.
work_dir
cfg
.
work_dir
=
args
.
work_dir
elif
cfg
.
get
(
'work_dir'
,
None
)
is
None
:
elif
cfg
.
get
(
'work_dir'
,
None
)
is
None
:
# use config filename as default work_dir if cfg.work_dir is None
# use config filename as default work_dir if cfg.work_dir is None
cfg
.
work_dir
=
osp
.
join
(
'./work_dirs'
,
cfg
.
work_dir
=
osp
.
join
(
'./work_dirs'
,
osp
.
splitext
(
osp
.
basename
(
args
.
config
))[
0
])
osp
.
splitext
(
osp
.
basename
(
args
.
config
))[
0
])
cfg_data_dict
.
work_dir
=
cfg
.
work_dir
cfg_data_dict
.
work_dir
=
cfg
.
work_dir
print
(
'work_dir: '
,
cfg
.
work_dir
)
print
(
'work_dir: '
,
cfg
.
work_dir
)
dataset
=
build_dataset
(
cfg_data_dict
)
dataset
=
build_dataset
(
cfg_data_dict
)
data_loader
=
build_dataloader
(
data_loader
=
build_dataloader
(
dataset
,
dataset
,
samples_per_gpu
=
samples_per_gpu
,
samples_per_gpu
=
samples_per_gpu
,
workers_per_gpu
=
cfg
.
data
.
workers_per_gpu
,
workers_per_gpu
=
cfg
.
data
.
workers_per_gpu
,
dist
=
distributed
,
dist
=
distributed
,
shuffle
=
False
)
shuffle
=
False
)
# build the model and load checkpoint
# build the model and load checkpoint
cfg
.
model
.
train_cfg
=
None
cfg
.
model
.
train_cfg
=
None
model
=
build_model
(
cfg
.
model
,
test_cfg
=
cfg
.
get
(
'test_cfg'
))
model
=
build_model
(
cfg
.
model
,
test_cfg
=
cfg
.
get
(
'test_cfg'
))
fp16_cfg
=
cfg
.
get
(
'fp16'
,
None
)
fp16_cfg
=
cfg
.
get
(
'fp16'
,
None
)
if
fp16_cfg
is
not
None
:
if
fp16_cfg
is
not
None
:
wrap_fp16_model
(
model
)
wrap_fp16_model
(
model
)
checkpoint
=
load_checkpoint
(
model
,
args
.
checkpoint
,
map_location
=
'cpu'
)
checkpoint
=
load_checkpoint
(
model
,
args
.
checkpoint
,
map_location
=
'cpu'
)
if
args
.
fuse_conv_bn
:
if
args
.
fuse_conv_bn
:
model
=
fuse_conv_bn
(
model
)
model
=
fuse_conv_bn
(
model
)
if
not
distributed
:
if
not
distributed
:
model
=
MMDataParallel
(
model
,
device_ids
=
[
0
])
model
=
MMDataParallel
(
model
,
device_ids
=
[
0
])
outputs
=
single_gpu_test
(
model
,
data_loader
)
outputs
=
single_gpu_test
(
model
,
data_loader
)
else
:
else
:
model
=
MMDistributedDataParallel
(
model
=
MMDistributedDataParallel
(
model
.
cuda
(),
model
.
cuda
(),
device_ids
=
[
torch
.
cuda
.
current_device
()],
device_ids
=
[
torch
.
cuda
.
current_device
()],
broadcast_buffers
=
False
)
broadcast_buffers
=
False
)
outputs
=
multi_gpu_test
(
model
,
data_loader
,
args
.
tmpdir
,
outputs
=
multi_gpu_test
(
model
,
data_loader
,
args
.
tmpdir
,
args
.
gpu_collect
)
args
.
gpu_collect
)
rank
,
_
=
get_dist_info
()
rank
,
_
=
get_dist_info
()
if
rank
==
0
:
if
rank
==
0
:
if
args
.
format_only
:
if
args
.
format_only
:
dataset
.
format_results
(
outputs
,
prefix
=
cfg
.
work_dir
)
dataset
.
format_results
(
outputs
,
prefix
=
cfg
.
work_dir
)
elif
args
.
eval
:
elif
args
.
eval
:
print
(
'start evaluation!'
)
print
(
'start evaluation!'
)
print
(
dataset
.
evaluate
(
outputs
))
print
(
dataset
.
evaluate
(
outputs
))
# Script entry point.
if __name__ == '__main__':
    main()
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/tools/train.py
→
autonomous_driving/Online-HD-Map-Construction/tools/train.py
View file @
f3b13cad
from
__future__
import
division
from
__future__
import
division
import
argparse
import
argparse
import
copy
import
copy
import
mmcv
import
mmcv
import
os
import
os
import
time
import
time
import
torch
import
torch
import
warnings
import
warnings
from
mmcv
import
Config
,
DictAction
from
mmcv
import
Config
,
DictAction
from
mmcv.runner
import
get_dist_info
,
init_dist
from
mmcv.runner
import
get_dist_info
,
init_dist
from
os
import
path
as
osp
from
os
import
path
as
osp
from
mmdet
import
__version__
as
mmdet_version
from
mmdet
import
__version__
as
mmdet_version
from
mmdet3d
import
__version__
as
mmdet3d_version
from
mmdet3d
import
__version__
as
mmdet3d_version
from
mmdet3d.apis
import
train_model
from
mmdet3d.apis
import
train_model
from
mmdet3d.datasets
import
build_dataset
from
mmdet3d.datasets
import
build_dataset
from
mmdet3d.utils
import
collect_env
,
get_root_logger
from
mmdet3d.utils
import
collect_env
,
get_root_logger
from
mmseg
import
__version__
as
mmseg_version
from
mmseg
import
__version__
as
mmseg_version
# warper
# warper
from
mmdet_train
import
set_random_seed
from
mmdet_train
import
set_random_seed
# from builder import build_model
# from builder import build_model
from
mmdet3d.models
import
build_model
from
mmdet3d.models
import
build_model
def parse_args():
    """Build the training CLI and return the parsed arguments.

    Side effects: ensures ``os.environ['LOCAL_RANK']`` is set (the
    distributed launcher reads it) and maps the deprecated ``--options``
    flag onto ``--cfg-options``.
    """
    ap = argparse.ArgumentParser(description='Train a detector')
    ap.add_argument('config', help='train config file path')
    ap.add_argument('--work-dir', help='the dir to save logs and models')
    ap.add_argument('--resume-from',
                    help='the checkpoint file to resume from')
    ap.add_argument('--no-validate', action='store_true',
                    help='whether not to evaluate the checkpoint during training')
    # --gpus and --gpu-ids are alternative ways to select devices; allow
    # at most one of them.
    gpu_group = ap.add_mutually_exclusive_group()
    gpu_group.add_argument('--gpus', type=int,
                           help='number of gpus to use '
                                '(only applicable to non-distributed training)')
    gpu_group.add_argument('--gpu-ids', type=int, nargs='+',
                           help='ids of gpus to use '
                                '(only applicable to non-distributed training)')
    ap.add_argument('--seed', type=int, default=0, help='random seed')
    ap.add_argument('--deterministic', action='store_true',
                    help='whether to set deterministic options for CUDNN backend.')
    ap.add_argument('--options', nargs='+', action=DictAction,
                    help='override some settings in the used config, the key-value pair '
                         'in xxx=yyy format will be merged into config file (deprecate), '
                         'change to --cfg-options instead.')
    ap.add_argument('--cfg-options', nargs='+', action=DictAction,
                    help='override some settings in the used config, the key-value pair '
                         'in xxx=yyy format will be merged into config file. If the value to '
                         'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
                         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
                         'Note that the quotation marks are necessary and that no white space '
                         'is allowed.')
    ap.add_argument('--launcher',
                    choices=['none', 'pytorch', 'slurm', 'mpi'],
                    default='none',
                    help='job launcher')
    ap.add_argument('--local_rank', type=int, default=0)
    ap.add_argument('--autoscale-lr', action='store_true',
                    help='automatically scale lr with the number of gpus')
    parsed = ap.parse_args()

    # torch.distributed reads LOCAL_RANK from the environment; default it
    # from the CLI flag when the launcher did not already export it.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(parsed.local_rank)

    if parsed.options and parsed.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both specified, '
            '--options is deprecated in favor of --cfg-options')
    if parsed.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        parsed.cfg_options = parsed.options

    return parsed
def main():
    """Entry point: load the config, set up plugins/logging/seeding, build
    the model and datasets, then launch training via ``train_model``."""
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # CLI --cfg-options overrides take precedence over the config file.
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # import modules, registry will be updated
    import sys
    sys.path.append(os.path.abspath('.'))
    if hasattr(cfg, 'plugin'):
        if cfg.plugin:
            import importlib
            if hasattr(cfg, 'plugin_dir'):

                def import_path(plugin_dir):
                    # Convert a filesystem path like 'plugin/models/' into a
                    # dotted module path ('plugin.models') and import it so
                    # its registry entries become visible.
                    _module_dir = os.path.dirname(plugin_dir)
                    _module_dir = _module_dir.split('/')
                    _module_path = _module_dir[0]
                    for m in _module_dir[1:]:
                        _module_path = _module_path + '.' + m
                    print(f'importing {_module_path}/')
                    plg_lib = importlib.import_module(_module_path)

                # cfg.plugin_dir may be a single path or a list of paths.
                plugin_dirs = cfg.plugin_dir
                if not isinstance(plugin_dirs, list):
                    plugin_dirs = [plugin_dirs, ]
                for plugin_dir in plugin_dirs:
                    import_path(plugin_dir)
            else:
                # import dir is the dirpath for the config file
                _module_dir = os.path.dirname(args.config)
                _module_dir = _module_dir.split('/')
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + '.' + m
                print(f'importing {_module_path}/')
                plg_lib = importlib.import_module(_module_path)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    # specify logger name, if we still use 'mmdet', the output info will be
    # filtered and won't be saved in the log_file
    # TODO: ugly workaround to judge whether we are training det or seg model
    if cfg.model.type in ['EncoderDecoder3D']:
        logger_name = 'mmseg'
    else:
        logger_name = 'mmdet'
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.log_level, name=logger_name)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()
    logger.info(f'Model:\n{model}')

    # NOTE(review): the datasets appear to read work_dir from their config
    # dict — confirm against the dataset implementation.
    cfg.data.train.work_dir = cfg.work_dir
    cfg.data.val.work_dir = cfg.work_dir
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        # in case we use a dataset wrapper
        if 'dataset' in cfg.data.train:
            val_dataset.pipeline = cfg.data.train.dataset.pipeline
        else:
            val_dataset.pipeline = cfg.data.train.pipeline
        # set test_mode=False here in deep copied config
        # which do not affect AP/AR calculation later
        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow # noqa
        val_dataset.test_mode = False
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=mmdet_version,
            mmseg_version=mmseg_version,
            mmdet3d_version=mmdet3d_version,
            config=cfg.pretty_text,
            CLASSES=None,
            PALETTE=datasets[0].PALETTE  # for segmentors
            if hasattr(datasets[0], 'PALETTE') else None)
    # add an attribute for visualization convenience
    # model.CLASSES = datasets[0].CLASSES
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)
# Script entry point.
if __name__ == '__main__':
    main()
autonomous_driving/Online-HD-Map-Construction
-CVPR2023
/tools/visualization/renderer.py
→
autonomous_driving/Online-HD-Map-Construction/tools/visualization/renderer.py
View file @
f3b13cad
import
os.path
as
osp
import
os.path
as
osp
import
os
import
os
import
numpy
as
np
import
numpy
as
np
import
copy
import
copy
import
cv2
import
cv2
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
from
PIL
import
Image
from
PIL
import
Image
from
shapely.geometry
import
LineString
from
shapely.geometry
import
LineString
def remove_nan_values(uv):
    """Drop rows of ``uv`` whose u (column 0) or v (column 1) entry is NaN.

    Only the first two columns are inspected; any additional columns
    (e.g. a depth column) are carried through unchanged for surviving rows.
    """
    has_nan = np.isnan(uv[:, 0]) | np.isnan(uv[:, 1])
    return uv[~has_nan]
def points_ego2img(pts_ego, extrinsics, intrinsics):
    """Project ego-frame 3D points into image pixel coordinates.

    Args:
        pts_ego (array): (N, 3) points in the ego frame.
        extrinsics (array): 4x4 ego-to-camera transform.
        intrinsics (array): 3x3 camera intrinsic matrix.

    Returns:
        tuple: ``(uv, depth)`` where ``uv`` is (M, 2) pixel coordinates and
        ``depth`` the corresponding camera depths; rows with NaN u/v are
        dropped, so M <= N.
    """
    ones = np.ones([len(pts_ego), 1])
    pts_homogeneous = np.concatenate([pts_ego, ones], axis=-1)
    pts_cam = extrinsics @ pts_homogeneous.T
    uvd = (intrinsics @ pts_cam[:3, :]).T
    uvd = remove_nan_values(uvd)
    depth = uvd[:, 2]
    # Perspective divide: normalise by the third (depth) component.
    uv = uvd[:, :2] / depth.reshape(-1, 1)
    return uv, depth
def interp_fixed_dist(line, sample_dist):
    '''Interpolate a line at fixed interval.

    Args:
        line (LineString): line
        sample_dist (float): sample interval

    Returns:
        points (array): interpolated points, shape (N, 2)
    '''
    interior = np.arange(sample_dist, line.length, sample_dist).tolist()
    # Prepend 0 and append line.length so that even when
    # sample_dist > line.length at least the two endpoints are sampled.
    all_dists = [0] + interior + [line.length]
    samples = [list(line.interpolate(d).coords) for d in all_dists]
    return np.array(samples).squeeze()
def _draw_uv_run(run_pts, img_bgr, color_bgr, thickness):
    """Rasterize one run of >= 2 consecutive in-view pixel points onto img_bgr."""
    pts = np.round(np.stack(run_pts)).astype(np.int32)
    draw_visible_polyline_cv2(
        copy.deepcopy(pts),
        # Every point in a run is visible by construction, so the mask is
        # sized to the run (the old code sized it to the full uv array).
        valid_pts_bool=np.ones((len(pts), 1), dtype=bool),
        image=img_bgr,
        color=color_bgr,
        thickness_px=thickness,
    )


def draw_polyline_ego_on_img(polyline_ego, img_bgr, extrinsics, intrinsics, color_bgr, thickness):
    """Project an ego-frame polyline onto a camera image and draw it in place.

    The polyline is densified to 0.2 m spacing, projected with
    ``points_ego2img``, and split into runs of consecutive points that fall
    inside the image with positive depth; each run of >= 2 points is drawn,
    so segments crossing out of view are not rendered.

    Args:
        polyline_ego (np.ndarray): (N, 2) or (N, 3) ego-frame points;
            z=0 is assumed for 2-D input.
        img_bgr (np.ndarray): (H, W, 3) BGR image, modified in place.
        extrinsics (np.ndarray): (4, 4) ego-to-camera transform.
        intrinsics (np.ndarray): (3, 3) camera intrinsics.
        color_bgr (tuple): BGR color.
        thickness (int): line thickness in pixels.
    """
    # if 2-dimension, assume z=0
    if polyline_ego.shape[1] == 2:
        zeros = np.zeros((polyline_ego.shape[0], 1))
        polyline_ego = np.concatenate([polyline_ego, zeros], axis=1)

    polyline_ego = interp_fixed_dist(line=LineString(polyline_ego), sample_dist=0.2)

    uv, depth = points_ego2img(polyline_ego, extrinsics, intrinsics)

    h, w, c = img_bgr.shape

    is_valid_x = np.logical_and(0 <= uv[:, 0], uv[:, 0] < w - 1)
    is_valid_y = np.logical_and(0 <= uv[:, 1], uv[:, 1] < h - 1)
    is_valid_z = depth > 0
    is_valid_points = np.logical_and.reduce([is_valid_x, is_valid_y, is_valid_z])

    if is_valid_points.sum() == 0:
        return

    # Accumulate consecutive visible points into runs; flush a run whenever
    # an invisible point breaks it, and once more after the loop.
    run = []
    for point, valid in zip(uv, is_valid_points):
        if valid:
            run.append(point)
            continue
        if len(run) >= 2:
            _draw_uv_run(run, img_bgr, color_bgr, thickness)
        run = []

    if len(run) >= 2:
        _draw_uv_run(run, img_bgr, color_bgr, thickness)
def draw_visible_polyline_cv2(line, valid_pts_bool, image, color, thickness_px):
    """Draw a polyline onto an image using given line segments.

    Args:
        line: Array of shape (K, 2) representing the coordinates of line.
        valid_pts_bool: Array of shape (K,) representing which polyline coordinates are valid for rendering.
            For example, if the coordinate is occluded, a user might specify that it is invalid.
            Line segments touching an invalid vertex will not be rendered.
        image: Array of shape (H, W, 3), representing a 3-channel BGR image
        color: Tuple of shape (3,) with a BGR format color
        thickness_px: thickness (in pixels) to use when rendering the polyline.
    """
    line = np.round(line).astype(int)  # type: ignore
    num_pts = len(line)
    for idx in range(num_pts - 1):
        # A segment is drawn only when both of its endpoints are valid.
        if not (valid_pts_bool[idx] and valid_pts_bool[idx + 1]):
            continue

        start = (line[idx][0], line[idx][1])
        end = (line[idx + 1][0], line[idx + 1][1])

        # Use anti-aliasing (AA) for curves
        image = cv2.line(image, pt1=start, pt2=end, color=color,
                         thickness=thickness_px, lineType=cv2.LINE_AA)
# Per-category colors used when drawing on camera images with OpenCV
# (cv2 expects BGR channel order).
COLOR_MAPS_BGR = {
    # bgr colors
    'divider': (0, 0, 255),
    'boundary': (0, 255, 0),
    'ped_crossing': (255, 0, 0),
    'centerline': (51, 183, 255),
    'drivable_area': (171, 255, 255)
}

# Matplotlib color codes for the same categories, used for BEV plotting.
COLOR_MAPS_PLT = {
    'divider': 'r',
    'boundary': 'g',
    'ped_crossing': 'b',
    'centerline': 'orange',
    'drivable_area': 'y',
}

# Argoverse 2 ring-camera names; index i names the output file for imgs[i]
# in Renderer.render_camera_views_from_vectors.
CAM_NAMES_AV2 = ['ring_front_center', 'ring_front_right', 'ring_front_left',
                 'ring_rear_right', 'ring_rear_left', 'ring_side_right', 'ring_side_left',
                 ]
class Renderer(object):
    """Render map elements on image views.

    Args:
        roi_size (tuple): bev range
    """

    def __init__(self, roi_size):
        # (x-extent, y-extent) of the BEV region of interest.
        self.roi_size = roi_size

    def render_bev_from_vectors(self, vectors, out_dir):
        '''Plot vectorized map elements on BEV.

        Args:
            vectors (dict): dict of vectorized map elements.
            out_dir (str): output directory
        '''
        roi_x = self.roi_size[0]
        roi_y = self.roi_size[1]
        car_img = Image.open('resources/images/car.png')
        map_path = os.path.join(out_dir, 'map.jpg')

        plt.figure(figsize=(roi_x, roi_y))
        plt.xlim(-roi_x / 2 - 1, roi_x / 2 + 1)
        plt.ylim(-roi_y / 2 - 1, roi_y / 2 + 1)
        plt.axis('off')
        # Ego-vehicle icon centered at the origin.
        plt.imshow(car_img, extent=[-1.5, 1.5, -1.2, 1.2])

        for cat in vectors:
            color = COLOR_MAPS_PLT[cat]
            for vector in vectors[cat]:
                pts = np.array(vector)[:, :2]
                xs = np.array([p[0] for p in pts])
                ys = np.array([p[1] for p in pts])
                plt.plot(xs, ys, color=color, linewidth=5, marker='o',
                         linestyle='-', markersize=20)

        plt.savefig(map_path, bbox_inches='tight', dpi=40)
        plt.close()

    def render_camera_views_from_vectors(self, vectors, imgs, extrinsics,
                                         intrinsics, thickness, out_dir):
        '''Project vectorized map elements to camera views.

        Args:
            vectors (dict): dict of vectorized map elements.
            imgs (tensor): images in bgr color.
            extrinsics (array): ego2img extrinsics, shape (4, 4)
            intrinsics (array): intrinsics, shape (3, 3)
            thickness (int): thickness of lines to draw on images.
            out_dir (str): output directory
        '''
        for cam_idx in range(len(imgs)):
            extrinsic = extrinsics[cam_idx]
            intrinsic = intrinsics[cam_idx]
            canvas = copy.deepcopy(imgs[cam_idx])

            for cat, vector_list in vectors.items():
                color = COLOR_MAPS_BGR[cat]
                for vector in vector_list:
                    # cv2 drawing needs a C-contiguous buffer.
                    canvas = np.ascontiguousarray(canvas)
                    pts = np.array(vector)
                    # Keep at most (x, y, z); drop any trailing attributes.
                    if pts.shape[1] > 3:
                        pts = pts[:, :3]
                    draw_polyline_ego_on_img(pts, canvas, extrinsic, intrinsic,
                                             color, thickness)

            out_path = osp.join(out_dir, CAM_NAMES_AV2[cam_idx]) + '.jpg'
            cv2.imwrite(out_path, canvas)
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment