Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
169f098d
Commit
169f098d
authored
Mar 22, 2022
by
Yifei Yang
Committed by
zhouzaida
Jul 19, 2022
Browse files
[Fix] Set keypoints not in the cropped image invisible (#1804)
* set invisiblity * fix as comment
parent
6534efd6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
8 deletions
+17
-8
mmcv/transforms/processing.py
mmcv/transforms/processing.py
+10
-1
tests/test_transforms/test_transforms_processing.py
tests/test_transforms/test_transforms_processing.py
+7
-7
No files found.
mmcv/transforms/processing.py
View file @
169f098d
...
...
@@ -531,7 +531,8 @@ class CenterCrop(BaseTransform):
results
[
'gt_bboxes'
]
=
gt_bboxes
def
_crop_keypoints
(
self
,
results
:
dict
,
bboxes
:
np
.
ndarray
)
->
None
:
"""Update key points according to CenterCrop.
"""Update key points according to CenterCrop. Keypoints that not in the
cropped image will be set invisible.
Args:
results (dict): Result dict contains the data to transform.
...
...
@@ -544,6 +545,14 @@ class CenterCrop(BaseTransform):
# gt_keypoints has shape (N, NK, 3) in (x, y, visibility) order,
# NK = number of points per object
gt_keypoints
=
results
[
'gt_keypoints'
]
-
keypoints_offset
# set gt_kepoints out of the result image invisible
height
,
width
=
results
[
'img'
].
shape
[:
2
]
valid_pos
=
(
gt_keypoints
[:,
:,
0
]
>=
0
)
*
(
gt_keypoints
[:,
:,
0
]
<
width
)
*
(
gt_keypoints
[:,
:,
1
]
>=
0
)
*
(
gt_keypoints
[:,
:,
1
]
<
height
)
gt_keypoints
[:,
:,
2
]
=
np
.
where
(
valid_pos
,
gt_keypoints
[:,
:,
2
],
0
)
gt_keypoints
[:,
:,
0
]
=
np
.
clip
(
gt_keypoints
[:,
:,
0
],
0
,
results
[
'img'
].
shape
[
1
])
gt_keypoints
[:,
:,
1
]
=
np
.
clip
(
gt_keypoints
[:,
:,
1
],
0
,
...
...
tests/test_transforms/test_transforms_processing.py
View file @
169f098d
...
...
@@ -302,7 +302,7 @@ class TestCenterCrop:
224
]])).
all
()
assert
np
.
equal
(
results
[
'gt_keypoints'
],
np
.
array
([[[
0
,
12
,
1
]],
[[
112
,
112
,
1
]],
[[
212
,
187
,
1
]]])).
all
()
np
.
array
([[[
0
,
12
,
0
]],
[[
112
,
112
,
1
]],
[[
212
,
187
,
1
]]])).
all
()
# test CenterCrop when size is tuple
transform
=
dict
(
type
=
'CenterCrop'
,
crop_size
=
(
224
,
224
))
...
...
@@ -321,7 +321,7 @@ class TestCenterCrop:
224
]])).
all
()
assert
np
.
equal
(
results
[
'gt_keypoints'
],
np
.
array
([[[
0
,
12
,
1
]],
[[
112
,
112
,
1
]],
[[
212
,
187
,
1
]]])).
all
()
np
.
array
([[[
0
,
12
,
0
]],
[[
112
,
112
,
1
]],
[[
212
,
187
,
1
]]])).
all
()
# test CenterCrop when crop_height != crop_width
transform
=
dict
(
type
=
'CenterCrop'
,
crop_size
=
(
224
,
256
))
...
...
@@ -340,7 +340,7 @@ class TestCenterCrop:
256
]])).
all
()
assert
np
.
equal
(
results
[
'gt_keypoints'
],
np
.
array
([[[
0
,
28
,
1
]],
[[
112
,
128
,
1
]],
[[
212
,
203
,
1
]]])).
all
()
np
.
array
([[[
0
,
28
,
0
]],
[[
112
,
128
,
1
]],
[[
212
,
203
,
1
]]])).
all
()
# test CenterCrop when crop_size is equal to img.shape
img_height
,
img_width
,
_
=
self
.
original_img
.
shape
...
...
@@ -398,7 +398,7 @@ class TestCenterCrop:
300
]])).
all
()
assert
np
.
equal
(
results
[
'gt_keypoints'
],
np
.
array
([[[
0
,
50
,
1
]],
[[
100
,
150
,
1
]],
[[
200
,
225
,
1
]]])).
all
()
np
.
array
([[[
0
,
50
,
0
]],
[[
100
,
150
,
1
]],
[[
200
,
225
,
0
]]])).
all
()
transform
=
dict
(
type
=
'CenterCrop'
,
...
...
@@ -418,7 +418,7 @@ class TestCenterCrop:
300
]])).
all
()
assert
np
.
equal
(
results
[
'gt_keypoints'
],
np
.
array
([[[
0
,
50
,
1
]],
[[
100
,
150
,
1
]],
[[
200
,
225
,
1
]]])).
all
()
np
.
array
([[[
0
,
50
,
0
]],
[[
100
,
150
,
1
]],
[[
200
,
225
,
0
]]])).
all
()
# test CenterCrop when crop_width is smaller than img_width
transform
=
dict
(
...
...
@@ -438,7 +438,7 @@ class TestCenterCrop:
300
]])).
all
()
assert
np
.
equal
(
results
[
'gt_keypoints'
],
np
.
array
([[[
0
,
50
,
1
]],
[[
100
,
150
,
1
]],
[[
200
,
225
,
1
]]])).
all
()
np
.
array
([[[
0
,
50
,
0
]],
[[
100
,
150
,
1
]],
[[
200
,
225
,
0
]]])).
all
()
# test CenterCrop when crop_height is smaller than img_height
transform
=
dict
(
...
...
@@ -457,7 +457,7 @@ class TestCenterCrop:
150
]])).
all
()
assert
np
.
equal
(
results
[
'gt_keypoints'
],
np
.
array
([[[
20
,
0
,
1
]],
[[
200
,
75
,
1
]],
[[
300
,
150
,
1
]]])).
all
()
np
.
array
([[[
20
,
0
,
0
]],
[[
200
,
75
,
1
]],
[[
300
,
150
,
0
]]])).
all
()
@
pytest
.
mark
.
skipif
(
condition
=
torch
is
None
,
reason
=
'No torch in current env'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment