Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
cb0dd8ee
"blogs/vscode:/vscode.git/clone" did not exist on "c25a91b60c5192065dfdcabd373b947aa2234fe1"
Commit
cb0dd8ee
authored
Jul 13, 2019
by
Cao Yuhang
Committed by
Kai Chen
Jul 13, 2019
Browse files
support fp16 for maskiou_head (#986)
parent
713e98bc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
0 deletions
+5
-0
mmdet/models/mask_heads/maskiou_head.py
mmdet/models/mask_heads/maskiou_head.py
+5
-0
No files found.
mmdet/models/mask_heads/maskiou_head.py
View file @
cb0dd8ee
...
@@ -2,6 +2,7 @@ import numpy as np
...
@@ -2,6 +2,7 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmcv.cnn
import
kaiming_init
,
normal_init
from
mmcv.cnn
import
kaiming_init
,
normal_init
from
mmdet.core
import
force_fp32
from
..builder
import
build_loss
from
..builder
import
build_loss
from
..registry
import
HEADS
from
..registry
import
HEADS
...
@@ -28,6 +29,7 @@ class MaskIoUHead(nn.Module):
...
@@ -28,6 +29,7 @@ class MaskIoUHead(nn.Module):
self
.
conv_out_channels
=
conv_out_channels
self
.
conv_out_channels
=
conv_out_channels
self
.
fc_out_channels
=
fc_out_channels
self
.
fc_out_channels
=
fc_out_channels
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
fp16_enabled
=
False
self
.
convs
=
nn
.
ModuleList
()
self
.
convs
=
nn
.
ModuleList
()
for
i
in
range
(
num_convs
):
for
i
in
range
(
num_convs
):
...
@@ -82,6 +84,7 @@ class MaskIoUHead(nn.Module):
...
@@ -82,6 +84,7 @@ class MaskIoUHead(nn.Module):
mask_iou
=
self
.
fc_mask_iou
(
x
)
mask_iou
=
self
.
fc_mask_iou
(
x
)
return
mask_iou
return
mask_iou
@
force_fp32
(
apply_to
=
(
'mask_iou_pred'
,
))
def
loss
(
self
,
mask_iou_pred
,
mask_iou_targets
):
def
loss
(
self
,
mask_iou_pred
,
mask_iou_targets
):
pos_inds
=
mask_iou_targets
>
0
pos_inds
=
mask_iou_targets
>
0
if
pos_inds
.
sum
()
>
0
:
if
pos_inds
.
sum
()
>
0
:
...
@@ -91,6 +94,7 @@ class MaskIoUHead(nn.Module):
...
@@ -91,6 +94,7 @@ class MaskIoUHead(nn.Module):
loss_mask_iou
=
mask_iou_pred
*
0
loss_mask_iou
=
mask_iou_pred
*
0
return
dict
(
loss_mask_iou
=
loss_mask_iou
)
return
dict
(
loss_mask_iou
=
loss_mask_iou
)
@
force_fp32
(
apply_to
=
(
'mask_pred'
,
))
def
get_target
(
self
,
sampling_results
,
gt_masks
,
mask_pred
,
mask_targets
,
def
get_target
(
self
,
sampling_results
,
gt_masks
,
mask_pred
,
mask_targets
,
rcnn_train_cfg
):
rcnn_train_cfg
):
"""Compute target of mask IoU.
"""Compute target of mask IoU.
...
@@ -166,6 +170,7 @@ class MaskIoUHead(nn.Module):
...
@@ -166,6 +170,7 @@ class MaskIoUHead(nn.Module):
area_ratios
=
pos_proposals
.
new_zeros
((
0
,
))
area_ratios
=
pos_proposals
.
new_zeros
((
0
,
))
return
area_ratios
return
area_ratios
@
force_fp32
(
apply_to
=
(
'mask_iou_pred'
,
))
def
get_mask_scores
(
self
,
mask_iou_pred
,
det_bboxes
,
det_labels
):
def
get_mask_scores
(
self
,
mask_iou_pred
,
det_bboxes
,
det_labels
):
"""Get the mask scores.
"""Get the mask scores.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment