Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
InstructBLIP_pytorch
Commits
c04f261a
Commit
c04f261a
authored
Aug 22, 2024
by
dongchy920
Browse files
InstructBLIP
parents
Pipeline
#1594
canceled with stages
Changes
421
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2116 additions
and
0 deletions
+2116
-0
lavis/common/annotator/uniformer/mmcv/video/processing.py
lavis/common/annotator/uniformer/mmcv/video/processing.py
+160
-0
lavis/common/annotator/uniformer/mmcv/visualization/__init__.py
...common/annotator/uniformer/mmcv/visualization/__init__.py
+9
-0
lavis/common/annotator/uniformer/mmcv/visualization/color.py
lavis/common/annotator/uniformer/mmcv/visualization/color.py
+51
-0
lavis/common/annotator/uniformer/mmcv/visualization/image.py
lavis/common/annotator/uniformer/mmcv/visualization/image.py
+152
-0
lavis/common/annotator/uniformer/mmcv/visualization/optflow.py
.../common/annotator/uniformer/mmcv/visualization/optflow.py
+112
-0
lavis/common/annotator/uniformer/mmcv_custom/__init__.py
lavis/common/annotator/uniformer/mmcv_custom/__init__.py
+6
-0
lavis/common/annotator/uniformer/mmcv_custom/checkpoint.py
lavis/common/annotator/uniformer/mmcv_custom/checkpoint.py
+501
-0
lavis/common/annotator/uniformer/mmseg/apis/__init__.py
lavis/common/annotator/uniformer/mmseg/apis/__init__.py
+9
-0
lavis/common/annotator/uniformer/mmseg/apis/inference.py
lavis/common/annotator/uniformer/mmseg/apis/inference.py
+136
-0
lavis/common/annotator/uniformer/mmseg/apis/test.py
lavis/common/annotator/uniformer/mmseg/apis/test.py
+238
-0
lavis/common/annotator/uniformer/mmseg/apis/train.py
lavis/common/annotator/uniformer/mmseg/apis/train.py
+116
-0
lavis/common/annotator/uniformer/mmseg/core/__init__.py
lavis/common/annotator/uniformer/mmseg/core/__init__.py
+3
-0
lavis/common/annotator/uniformer/mmseg/core/evaluation/__init__.py
...mon/annotator/uniformer/mmseg/core/evaluation/__init__.py
+8
-0
lavis/common/annotator/uniformer/mmseg/core/evaluation/class_names.py
.../annotator/uniformer/mmseg/core/evaluation/class_names.py
+152
-0
lavis/common/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py
...n/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py
+109
-0
lavis/common/annotator/uniformer/mmseg/core/evaluation/metrics.py
...mmon/annotator/uniformer/mmseg/core/evaluation/metrics.py
+326
-0
lavis/common/annotator/uniformer/mmseg/core/seg/__init__.py
lavis/common/annotator/uniformer/mmseg/core/seg/__init__.py
+4
-0
lavis/common/annotator/uniformer/mmseg/core/seg/builder.py
lavis/common/annotator/uniformer/mmseg/core/seg/builder.py
+8
-0
lavis/common/annotator/uniformer/mmseg/core/seg/sampler/__init__.py
...on/annotator/uniformer/mmseg/core/seg/sampler/__init__.py
+4
-0
lavis/common/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py
...or/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py
+12
-0
No files found.
Too many changes to show.
To preserve performance only
421 of 421+
files are displayed.
Plain diff
Email patch
lavis/common/annotator/uniformer/mmcv/video/processing.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
os.path
as
osp
import
subprocess
import
tempfile
from
annotator.uniformer.mmcv.utils
import
requires_executable
@requires_executable('ffmpeg')
def convert_video(in_file, out_file, print_cmd=False, pre_options='', **kwargs):
    """Convert a video with ffmpeg.

    This provides a general api to ffmpeg, the executed command is::

        `ffmpeg -y <pre_options> -i <in_file> <options> <out_file>`

    Options(kwargs) are mapped to ffmpeg commands with the following rules:

    - key=val: "-key val"
    - key=True: "-key"
    - key=False: ""

    Args:
        in_file (str): Input video filename.
        out_file (str): Output video filename.
        pre_options (str): Options appears before "-i <in_file>".
        print_cmd (bool): Whether to print the final ffmpeg command.
    """
    # Translate each keyword argument into an ffmpeg flag.
    options = []
    for key, value in kwargs.items():
        if isinstance(value, bool):
            # Boolean True means a bare flag; False means "omit entirely".
            if value:
                options.append(f'-{key}')
        elif key == 'log_level':
            # ffmpeg only accepts a fixed set of log levels.
            assert value in [
                'quiet', 'panic', 'fatal', 'error', 'warning', 'info',
                'verbose', 'debug', 'trace'
            ]
            options.append(f'-loglevel {value}')
        else:
            options.append(f'-{key} {value}')
    cmd = f'ffmpeg -y {pre_options} -i {in_file} {" ".join(options)} ' \
        f'{out_file}'
    if print_cmd:
        print(cmd)
    # NOTE: the command is run through the shell; filenames with spaces or
    # shell metacharacters are the caller's responsibility.
    subprocess.call(cmd, shell=True)
@requires_executable('ffmpeg')
def resize_video(in_file,
                 out_file,
                 size=None,
                 ratio=None,
                 keep_ar=False,
                 log_level='info',
                 print_cmd=False):
    """Resize a video.

    Args:
        in_file (str): Input video filename.
        out_file (str): Output video filename.
        size (tuple): Expected size (w, h), eg, (320, 240) or (320, -1).
        ratio (tuple or float): Expected resize ratio, (2, 0.5) means
            (w*2, h*0.5).
        keep_ar (bool): Whether to keep original aspect ratio.
        log_level (str): Logging level of ffmpeg.
        print_cmd (bool): Whether to print the final ffmpeg command.
    """
    # Exactly one of `size` / `ratio` must be given.
    if size is None and ratio is None:
        raise ValueError('expected size or ratio must be specified')
    if size is not None and ratio is not None:
        raise ValueError('size and ratio cannot be specified at the same time')
    options = {'log_level': log_level}
    if size:
        if not keep_ar:
            # Force exact output dimensions.
            options['vf'] = f'scale={size[0]}:{size[1]}'
        else:
            # Fit inside (w, h) while preserving the aspect ratio.
            options['vf'] = f'scale=w={size[0]}:h={size[1]}:' \
                'force_original_aspect_ratio=decrease'
    else:
        # A scalar ratio applies to both dimensions.
        if not isinstance(ratio, tuple):
            ratio = (ratio, ratio)
        options['vf'] = f'scale="trunc(iw*{ratio[0]}):trunc(ih*{ratio[1]})"'
    convert_video(in_file, out_file, print_cmd, **options)
@requires_executable('ffmpeg')
def cut_video(in_file,
              out_file,
              start=None,
              end=None,
              vcodec=None,
              acodec=None,
              log_level='info',
              print_cmd=False):
    """Cut a clip from a video.

    Args:
        in_file (str): Input video filename.
        out_file (str): Output video filename.
        start (None or float): Start time (in seconds).
        end (None or float): End time (in seconds).
        vcodec (None or str): Output video codec, None for unchanged.
        acodec (None or str): Output audio codec, None for unchanged.
        log_level (str): Logging level of ffmpeg.
        print_cmd (bool): Whether to print the final ffmpeg command.
    """
    options = {'log_level': log_level}
    # Default to stream copy (no re-encoding) when codecs are unspecified.
    if vcodec is None:
        options['vcodec'] = 'copy'
    if acodec is None:
        options['acodec'] = 'copy'
    if start:
        options['ss'] = start
    else:
        start = 0
    if end:
        # ffmpeg's -t takes a duration, not an end timestamp.
        options['t'] = end - start
    convert_video(in_file, out_file, print_cmd, **options)
@requires_executable('ffmpeg')
def concat_video(video_list,
                 out_file,
                 vcodec=None,
                 acodec=None,
                 log_level='info',
                 print_cmd=False):
    """Concatenate multiple videos into a single one.

    Args:
        video_list (list): A list of video filenames
        out_file (str): Output video filename
        vcodec (None or str): Output video codec, None for unchanged
        acodec (None or str): Output audio codec, None for unchanged
        log_level (str): Logging level of ffmpeg.
        print_cmd (bool): Whether to print the final ffmpeg command.
    """
    # The concat demuxer reads the input list from a text file, one
    # `file <path>` line per video.
    tmp_filehandler, tmp_filename = tempfile.mkstemp(suffix='.txt', text=True)
    with open(tmp_filename, 'w') as f:
        for filename in video_list:
            f.write(f'file {osp.abspath(filename)}\n')
    options = {'log_level': log_level}
    if vcodec is None:
        options['vcodec'] = 'copy'
    if acodec is None:
        options['acodec'] = 'copy'
    # `-safe 0` allows absolute paths in the list file.
    convert_video(
        tmp_filename,
        out_file,
        print_cmd,
        pre_options='-f concat -safe 0',
        **options)
    os.close(tmp_filehandler)
    os.remove(tmp_filename)
lavis/common/annotator/uniformer/mmcv/visualization/__init__.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
from
.color
import
Color
,
color_val
from
.image
import
imshow
,
imshow_bboxes
,
imshow_det_bboxes
from
.optflow
import
flow2rgb
,
flowshow
,
make_color_wheel
__all__
=
[
'Color'
,
'color_val'
,
'imshow'
,
'imshow_bboxes'
,
'imshow_det_bboxes'
,
'flowshow'
,
'flow2rgb'
,
'make_color_wheel'
]
lavis/common/annotator/uniformer/mmcv/visualization/color.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
from
enum
import
Enum
import
numpy
as
np
from
annotator.uniformer.mmcv.utils
import
is_str
class Color(Enum):
    """An enum that defines common colors.

    Contains red, green, blue, cyan, yellow, magenta, white and black.

    Values are BGR tuples (OpenCV channel order), not RGB.
    """
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)
def color_val(color):
    """Convert various input to color tuples.

    Args:
        color (:obj:`Color`/str/tuple/int/ndarray): Color inputs

    Returns:
        tuple[int]: A tuple of 3 integers indicating BGR channels.
    """
    if is_str(color):
        # Look the name up in the Color enum.
        return Color[color].value
    elif isinstance(color, Color):
        return color.value
    elif isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    elif isinstance(color, int):
        # A single int becomes a gray level.
        assert 0 <= color <= 255
        return color, color, color
    elif isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    else:
        raise TypeError(f'Invalid type for color: {type(color)}')
lavis/common/annotator/uniformer/mmcv/visualization/image.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
cv2
import
numpy
as
np
from
annotator.uniformer.mmcv.image
import
imread
,
imwrite
from
.color
import
color_val
def imshow(img, win_name='', wait_time=0):
    """Show an image.

    Args:
        img (str or ndarray): The image to be displayed.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
    """
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # prevent from hanging if windows was closed
        while True:
            ret = cv2.waitKey(1)
            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # if user closed window or if some key pressed
            if closed or ret != -1:
                break
    else:
        ret = cv2.waitKey(wait_time)
def imshow_bboxes(img,
                  bboxes,
                  colors='green',
                  top_k=-1,
                  thickness=1,
                  show=True,
                  win_name='',
                  wait_time=0,
                  out_file=None):
    """Draw bboxes on an image.

    Args:
        img (str or ndarray): The image to be displayed.
        bboxes (list or ndarray): A list of ndarray of shape (k, 4).
        colors (list[str or tuple or Color]): A list of colors.
        top_k (int): Plot the first k bboxes only if set positive.
        thickness (int): Thickness of lines.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str, optional): The filename to write the image.

    Returns:
        ndarray: The image with bboxes drawn on it.
    """
    img = imread(img)
    img = np.ascontiguousarray(img)
    # Normalize inputs: a single array becomes a one-element list, and a
    # single color is broadcast to every bbox group.
    if isinstance(bboxes, np.ndarray):
        bboxes = [bboxes]
    if not isinstance(colors, list):
        colors = [colors for _ in range(len(bboxes))]
    colors = [color_val(c) for c in colors]
    assert len(bboxes) == len(colors)

    for i, _bboxes in enumerate(bboxes):
        _bboxes = _bboxes.astype(np.int32)
        # Non-positive top_k means "draw everything".
        if top_k <= 0:
            _top_k = _bboxes.shape[0]
        else:
            _top_k = min(top_k, _bboxes.shape[0])
        for j in range(_top_k):
            left_top = (_bboxes[j, 0], _bboxes[j, 1])
            right_bottom = (_bboxes[j, 2], _bboxes[j, 3])
            cv2.rectangle(
                img, left_top, right_bottom, colors[i], thickness=thickness)

    if show:
        imshow(img, win_name, wait_time)
    if out_file is not None:
        imwrite(img, out_file)
    return img
def imshow_det_bboxes(img,
                      bboxes,
                      labels,
                      class_names=None,
                      score_thr=0,
                      bbox_color='green',
                      text_color='green',
                      thickness=1,
                      font_scale=0.5,
                      show=True,
                      win_name='',
                      wait_time=0,
                      out_file=None):
    """Draw bboxes and class labels (with scores) on an image.

    Args:
        img (str or ndarray): The image to be displayed.
        bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
            (n, 5).
        labels (ndarray): Labels of bboxes.
        class_names (list[str]): Names of each classes.
        score_thr (float): Minimum score of bboxes to be shown.
        bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
        text_color (str or tuple or :obj:`Color`): Color of texts.
        thickness (int): Thickness of lines.
        font_scale (float): Font scales of texts.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str or None): The filename to write the image.

    Returns:
        ndarray: The image with bboxes drawn on it.
    """
    assert bboxes.ndim == 2
    assert labels.ndim == 1
    assert bboxes.shape[0] == labels.shape[0]
    assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5
    img = imread(img)
    img = np.ascontiguousarray(img)

    if score_thr > 0:
        # Filtering by score requires the 5th (score) column.
        assert bboxes.shape[1] == 5
        scores = bboxes[:, -1]
        inds = scores > score_thr
        bboxes = bboxes[inds, :]
        labels = labels[inds]

    bbox_color = color_val(bbox_color)
    text_color = color_val(text_color)
    for bbox, label in zip(bboxes, labels):
        bbox_int = bbox.astype(np.int32)
        left_top = (bbox_int[0], bbox_int[1])
        right_bottom = (bbox_int[2], bbox_int[3])
        cv2.rectangle(
            img, left_top, right_bottom, bbox_color, thickness=thickness)
        label_text = class_names[
            label] if class_names is not None else f'cls {label}'
        if len(bbox) > 4:
            # Append the detection score when present.
            label_text += f'|{bbox[-1]:.02f}'
        # Place the text just above the top-left corner of the box.
        cv2.putText(img, label_text, (bbox_int[0], bbox_int[1] - 2),
                    cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)

    if show:
        imshow(img, win_name, wait_time)
    if out_file is not None:
        imwrite(img, out_file)
    return img
lavis/common/annotator/uniformer/mmcv/visualization/optflow.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
from
__future__
import
division
import
numpy
as
np
from
annotator.uniformer.mmcv.image
import
rgb2bgr
from
annotator.uniformer.mmcv.video
import
flowread
from
.image
import
imshow
def flowshow(flow, win_name='', wait_time=0):
    """Show optical flow.

    Args:
        flow (ndarray or str): The optical flow to be displayed.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
    """
    # Load (if a path), colorize, convert to BGR for OpenCV display.
    flow = flowread(flow)
    flow_img = flow2rgb(flow)
    imshow(rgb2bgr(flow_img), win_name, wait_time)
def flow2rgb(flow, color_wheel=None, unknown_thr=1e6):
    """Convert flow map to RGB image.

    Args:
        flow (ndarray): Array of optical flow.
        color_wheel (ndarray or None): Color wheel used to map flow field to
            RGB colorspace. Default color wheel will be used if not specified.
        unknown_thr (str): Values above this threshold will be marked as
            unknown and thus ignored.

    Returns:
        ndarray: RGB image that can be visualized.
    """
    assert flow.ndim == 3 and flow.shape[-1] == 2
    if color_wheel is None:
        color_wheel = make_color_wheel()
    assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3
    num_bins = color_wheel.shape[0]

    dx = flow[:, :, 0].copy()
    dy = flow[:, :, 1].copy()

    # NaNs and implausibly large values are treated as unknown flow and
    # zeroed out (rendered black at the end).
    ignore_inds = (
        np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) |
        (np.abs(dy) > unknown_thr))
    dx[ignore_inds] = 0
    dy[ignore_inds] = 0

    # Normalize magnitudes to [0, 1] by the maximum radius.
    rad = np.sqrt(dx**2 + dy**2)
    if np.any(rad > np.finfo(float).eps):
        max_rad = np.max(rad)
        dx /= max_rad
        dy /= max_rad

    rad = np.sqrt(dx**2 + dy**2)
    angle = np.arctan2(-dy, -dx) / np.pi

    # Map angle in [-1, 1] to a fractional wheel position and linearly
    # interpolate between the two adjacent wheel colors.
    bin_real = (angle + 1) / 2 * (num_bins - 1)
    bin_left = np.floor(bin_real).astype(int)
    bin_right = (bin_left + 1) % num_bins
    w = (bin_real - bin_left.astype(np.float32))[..., None]
    flow_img = (1 -
                w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :]
    # Fade small motions toward white; darken saturated ones.
    small_ind = rad <= 1
    flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind])
    flow_img[np.logical_not(small_ind)] *= 0.75

    flow_img[ignore_inds, :] = 0

    return flow_img
def make_color_wheel(bins=None):
    """Build a color wheel.

    Args:
        bins(list or tuple, optional): Specify the number of bins for each
            color range, corresponding to six ranges: red -> yellow,
            yellow -> green, green -> cyan, cyan -> blue, blue -> magenta,
            magenta -> red. [15, 6, 4, 11, 13, 6] is used for default
            (see Middlebury).

    Returns:
        ndarray: Color wheel of shape (total_bins, 3).
    """
    if bins is None:
        bins = [15, 6, 4, 11, 13, 6]
    assert len(bins) == 6

    RY, YG, GC, CB, BM, MR = tuple(bins)

    # Each segment ramps exactly one channel while the others stay fixed.
    ry = [1, np.arange(RY) / RY, 0]
    yg = [1 - np.arange(YG) / YG, 1, 0]
    gc = [0, 1, np.arange(GC) / GC]
    cb = [0, 1 - np.arange(CB) / CB, 1]
    bm = [np.arange(BM) / BM, 0, 1]
    mr = [1, 0, 1 - np.arange(MR) / MR]

    num_bins = RY + YG + GC + CB + BM + MR

    # Fill channel-major (3, num_bins), then transpose to (num_bins, 3).
    color_wheel = np.zeros((3, num_bins), dtype=np.float32)

    col = 0
    for i, color in enumerate([ry, yg, gc, cb, bm, mr]):
        for j in range(3):
            color_wheel[j, col:col + bins[i]] = color[j]
        col += bins[i]

    return color_wheel.T
lavis/common/annotator/uniformer/mmcv_custom/__init__.py
0 → 100644
View file @
c04f261a
# -*- coding: utf-8 -*-
from
.checkpoint
import
load_checkpoint
__all__
=
[
'load_checkpoint'
]
\ No newline at end of file
lavis/common/annotator/uniformer/mmcv_custom/checkpoint.py
0 → 100644
View file @
c04f261a
# Copyright (c) Open-MMLab. All rights reserved.
import
io
import
os
import
os.path
as
osp
import
pkgutil
import
time
import
warnings
from
collections
import
OrderedDict
from
importlib
import
import_module
from
tempfile
import
TemporaryDirectory
import
torch
import
torchvision
from
torch.optim
import
Optimizer
from
torch.utils
import
model_zoo
from
torch.nn
import
functional
as
F
import
annotator.uniformer.mmcv
as
mmcv
from
annotator.uniformer.mmcv.fileio
import
FileClient
from
annotator.uniformer.mmcv.fileio
import
load
as
load_file
from
annotator.uniformer.mmcv.parallel
import
is_module_wrapper
from
annotator.uniformer.mmcv.utils
import
mkdir_or_exist
from
annotator.uniformer.mmcv.runner
import
get_dist_info
ENV_MMCV_HOME
=
'MMCV_HOME'
ENV_XDG_CACHE_HOME
=
'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR
=
'~/.cache'
def _get_mmcv_home():
    """Return the mmcv cache directory, creating it if necessary.

    Resolution order: $MMCV_HOME, then $XDG_CACHE_HOME/mmcv, then
    ~/.cache/mmcv.
    """
    mmcv_home = os.path.expanduser(
        os.getenv(
            ENV_MMCV_HOME,
            os.path.join(
                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))

    mkdir_or_exist(mmcv_home)
    return mmcv_home
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    # Shallow-copy so we can attach metadata without mutating the caller's dict.
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    # Only rank 0 reports, to avoid duplicated output in distributed runs.
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)
def load_url_dist(url, model_dir=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    # Rank 0 downloads first; other ranks wait at the barrier and then read
    # the file from the (now-populated) local cache.
    if rank == 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    return checkpoint
def load_pavimodel_dist(model_path, map_location=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    # Rank 0 downloads into a temp dir and loads; other ranks repeat the
    # download after the barrier (each rank gets its own temp copy).
    if rank == 0:
        model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            downloaded_file = osp.join(tmp_dir, model.name)
            model.download(downloaded_file)
            checkpoint = torch.load(downloaded_file, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            model = modelcloud.get(model_path)
            with TemporaryDirectory() as tmp_dir:
                downloaded_file = osp.join(tmp_dir, model.name)
                model.download(downloaded_file)
                checkpoint = torch.load(
                    downloaded_file, map_location=map_location)
    return checkpoint
def load_fileclient_dist(filename, backend, map_location):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    allowed_backends = ['ceph']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')
    # Rank 0 fetches first; remaining ranks fetch after the barrier.
    if rank == 0:
        fileclient = FileClient(backend=backend)
        buffer = io.BytesIO(fileclient.get(filename))
        checkpoint = torch.load(buffer, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            fileclient = FileClient(backend=backend)
            buffer = io.BytesIO(fileclient.get(filename))
            checkpoint = torch.load(buffer, map_location=map_location)
    return checkpoint
def get_torchvision_models():
    """Collect the ``model_urls`` dicts of every torchvision model module."""
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
        if ispkg:
            continue
        _zoo = import_module(f'torchvision.models.{name}')
        if hasattr(_zoo, 'model_urls'):
            _urls = getattr(_zoo, 'model_urls')
            model_urls.update(_urls)
    return model_urls
def get_external_models():
    """Load the open-mmlab model-url table, letting a user-level json in
    $MMCV_HOME override/extend the one shipped with mmcv."""
    mmcv_home = _get_mmcv_home()
    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    default_urls = load_file(default_json_path)
    assert isinstance(default_urls, dict)
    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
    if osp.exists(external_json_path):
        external_urls = load_file(external_json_path)
        assert isinstance(external_urls, dict)
        # User entries take precedence over the shipped defaults.
        default_urls.update(external_urls)

    return default_urls
def get_mmcls_models():
    """Load the mmcls model-url table shipped with mmcv."""
    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
    mmcls_urls = load_file(mmcls_json_path)
    return mmcls_urls
def get_deprecated_model_names():
    """Load the mapping from deprecated model names to their replacements."""
    deprecate_json_path = osp.join(mmcv.__path__[0],
                                   'model_zoo/deprecated.json')
    deprecate_urls = load_file(deprecate_json_path)
    assert isinstance(deprecate_urls, dict)

    return deprecate_urls
def
_process_mmcls_checkpoint
(
checkpoint
):
state_dict
=
checkpoint
[
'state_dict'
]
new_state_dict
=
OrderedDict
()
for
k
,
v
in
state_dict
.
items
():
if
k
.
startswith
(
'backbone.'
):
new_state_dict
[
k
[
9
:]]
=
v
new_checkpoint
=
dict
(
state_dict
=
new_state_dict
)
return
new_checkpoint
def _load_checkpoint(filename, map_location=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict | OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.
    """
    # Dispatch on the URI scheme of `filename`.
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        model_urls = get_torchvision_models()
        model_name = filename[11:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('torchvision://'):
        model_urls = get_torchvision_models()
        model_name = filename[14:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('open-mmlab://'):
        model_urls = get_external_models()
        model_name = filename[13:]
        deprecated_urls = get_deprecated_model_names()
        if model_name in deprecated_urls:
            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
                          f'of open-mmlab://{deprecated_urls[model_name]}')
            model_name = deprecated_urls[model_name]
        model_url = model_urls[model_name]
        # check if is url
        if model_url.startswith(('http://', 'https://')):
            checkpoint = load_url_dist(model_url)
        else:
            # Otherwise the table entry is a path relative to MMCV_HOME.
            filename = osp.join(_get_mmcv_home(), model_url)
            if not osp.isfile(filename):
                raise IOError(f'{filename} is not a checkpoint file')
            checkpoint = torch.load(filename, map_location=map_location)
    elif filename.startswith('mmcls://'):
        model_urls = get_mmcls_models()
        model_name = filename[8:]
        checkpoint = load_url_dist(model_urls[model_name])
        checkpoint = _process_mmcls_checkpoint(checkpoint)
    elif filename.startswith(('http://', 'https://')):
        checkpoint = load_url_dist(filename)
    elif filename.startswith('pavi://'):
        model_path = filename[7:]
        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
    elif filename.startswith('s3://'):
        checkpoint = load_fileclient_dist(
            filename, backend='ceph', map_location=map_location)
    else:
        # Plain local path.
        if not osp.isfile(filename):
            raise IOError(f'{filename} is not a checkpoint file')
        checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """

    # FIX: the original called `logger.warning(...)` unconditionally below,
    # which raises AttributeError when logger is None (the default). Fall
    # back to print, mirroring load_state_dict's own fallback behavior.
    def _warn(msg):
        if logger is not None:
            logger.warning(msg)
        else:
            print(msg)

    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict (DataParallel/DDP wrappers add 'module.')
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # for MoBY, load model of online branch
    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
        state_dict = {
            k.replace('encoder.', ''): v
            for k, v in state_dict.items() if k.startswith('encoder.')
        }

    # reshape absolute position embedding
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H * W:
            _warn("Error in loading absolute_pos_embed, pass")
        else:
            # (N, L, C) -> (N, C, H, W) to match the model's layout.
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                N2, H, W, C2).permute(0, 3, 1, 2)

    # interpolate position bias table if needed
    relative_position_bias_table_keys = [
        k for k in state_dict.keys() if "relative_position_bias_table" in k
    ]
    for table_key in relative_position_bias_table_keys:
        table_pretrained = state_dict[table_key]
        table_current = model.state_dict()[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            _warn(f"Error in loading {table_key}, pass")
        else:
            if L1 != L2:
                # Bicubic-resize the (S1, S1) bias grid to (S2, S2).
                S1 = int(L1**0.5)
                S2 = int(L2**0.5)
                table_pretrained_resized = F.interpolate(
                    table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                    size=(S2, S2),
                    mode='bicubic')
                state_dict[table_key] = table_pretrained_resized.view(
                    nH2, L2).permute(1, 0)

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    # FIX: the original docstring said the return value was "on GPU".
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    return state_dict_cpu
def
_save_to_state_dict
(
module
,
destination
,
prefix
,
keep_vars
):
"""Saves module state to `destination` dictionary.
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
Args:
module (nn.Module): The module to generate state_dict.
destination (dict): A dict where state will be stored.
prefix (str): The prefix for parameters and buffers used in this
module.
"""
for
name
,
param
in
module
.
_parameters
.
items
():
if
param
is
not
None
:
destination
[
prefix
+
name
]
=
param
if
keep_vars
else
param
.
detach
()
for
name
,
buf
in
module
.
_buffers
.
items
():
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
if
buf
is
not
None
:
destination
[
prefix
+
name
]
=
buf
if
keep_vars
else
buf
.
detach
()
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.
    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        # A hook may replace the destination dict entirely.
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename. A ``pavi://`` prefix routes
            the save through the pavi modelcloud service instead of disk.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.

    Raises:
        TypeError: If ``meta`` is neither ``None`` nor a dict.
        ImportError: If a ``pavi://`` filename is given but pavi is not
            installed.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    # Stamp the checkpoint with the library version and wall-clock time.
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    # Unwrap (D)DP-style wrappers so keys are saved without the wrapper prefix.
    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    # Weights are moved to CPU so the file can be loaded without a GPU.
    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        # A dict of named optimizers: save each one under its name.
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    if filename.startswith('pavi://'):
        # pavi is an optional dependency; import lazily and fail with a
        # clear message when missing.
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]  # strip the 'pavi://' prefix
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            # Destination folder does not exist yet on modelcloud; create it.
            model = root.create_training_model(model_dir)
        # Serialize locally first, then upload the finished file.
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
\ No newline at end of file
lavis/common/annotator/uniformer/mmseg/apis/__init__.py
0 → 100644
View file @
c04f261a
from
.inference
import
inference_segmentor
,
init_segmentor
,
show_result_pyplot
from
.test
import
multi_gpu_test
,
single_gpu_test
from
.train
import
get_root_logger
,
set_random_seed
,
train_segmentor
__all__
=
[
'get_root_logger'
,
'set_random_seed'
,
'train_segmentor'
,
'init_segmentor'
,
'inference_segmentor'
,
'multi_gpu_test'
,
'single_gpu_test'
,
'show_result_pyplot'
]
lavis/common/annotator/uniformer/mmseg/apis/inference.py
0 → 100644
View file @
c04f261a
import
matplotlib.pyplot
as
plt
import
annotator.uniformer.mmcv
as
mmcv
import
torch
from
annotator.uniformer.mmcv.parallel
import
collate
,
scatter
from
annotator.uniformer.mmcv.runner
import
load_checkpoint
from
annotator.uniformer.mmseg.datasets.pipelines
import
Compose
from
annotator.uniformer.mmseg.models
import
build_segmentor
def init_segmentor(config, checkpoint=None, device='cuda:0'):
    """Initialize a segmentor from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
            Use 'cpu' for loading model on CPU.

    Returns:
        nn.Module: The constructed segmentor, in eval mode.

    Raises:
        TypeError: If ``config`` is neither a path string nor a Config.
    """
    # Reject anything that is not a path or an already-parsed Config.
    if not isinstance(config, (str, mmcv.Config)):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)

    # Inference only: drop pretrained-weight path and training settings.
    config.model.pretrained = None
    config.model.train_cfg = None
    model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))

    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        # Class names and palette travel inside the checkpoint metadata.
        model.CLASSES = checkpoint['meta']['CLASSES']
        model.PALETTE = checkpoint['meta']['PALETTE']

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
class LoadImage:
    """A simple pipeline to load image."""

    def __call__(self, results):
        """Call function to load images into results.

        Args:
            results (dict): A result dict contains the file name
                of the image to be read.

        Returns:
            dict: ``results`` will be returned containing loaded image.
        """
        source = results['img']
        # Record the originating filename only when a path was supplied;
        # pre-loaded arrays have no filename.
        filename = source if isinstance(source, str) else None
        results['filename'] = filename
        results['ori_filename'] = filename

        loaded = mmcv.imread(source)
        results['img'] = loaded
        results['img_shape'] = loaded.shape
        results['ori_shape'] = loaded.shape
        return results
def inference_segmentor(model, img):
    """Inference image(s) with the segmentor.

    Args:
        model (nn.Module): The loaded segmentor.
        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        (list[Tensor]): The segmentation result.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    # build the data pipeline: swap the configured loader for LoadImage so
    # both paths and in-memory arrays are accepted
    pipeline = Compose([LoadImage()] + cfg.data.test.pipeline[1:])

    # prepare a single-sample batch
    sample = pipeline(dict(img=img))
    batch = collate([sample], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        batch = scatter(batch, [device])[0]
    else:
        # On CPU, unwrap the DataContainer metadata by hand.
        batch['img_metas'] = [meta.data[0] for meta in batch['img_metas']]

    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **batch)
    return result
def show_result_pyplot(model,
                       img,
                       result,
                       palette=None,
                       fig_size=(15, 10),
                       opacity=0.5,
                       title='',
                       block=True):
    """Render the segmentation result onto the image and return it as RGB.

    NOTE(review): the pyplot display that gave this function its name is
    disabled in this copy; ``fig_size``, ``title`` and ``block`` are kept
    only for API compatibility and are currently unused.

    Args:
        model (nn.Module): The loaded segmentor.
        img (str or np.ndarray): Image filename or loaded image.
        result (list): The segmentation result.
        palette (list[list[int]]] | None): The palette of segmentation
            map. If None is given, random palette will be generated.
            Default: None
        fig_size (tuple): Figure size of the pyplot figure (unused).
        opacity(float): Opacity of painted segmentation map.
            Default 0.5. Must be in (0, 1] range.
        title (str): The title of pyplot figure (unused).
        block (bool): Whether to block the pyplot figure (unused).

    Returns:
        np.ndarray: The rendered overlay, converted from BGR to RGB.
    """
    # Unwrap (D)DP-style wrappers to reach show_result on the raw model.
    if hasattr(model, 'module'):
        model = model.module
    rendered = model.show_result(
        img, result, palette=palette, show=False, opacity=opacity)
    return mmcv.bgr2rgb(rendered)
lavis/common/annotator/uniformer/mmseg/apis/test.py
0 → 100644
View file @
c04f261a
import
os.path
as
osp
import
pickle
import
shutil
import
tempfile
import
annotator.uniformer.mmcv
as
mmcv
import
numpy
as
np
import
torch
import
torch.distributed
as
dist
from
annotator.uniformer.mmcv.image
import
tensor2imgs
from
annotator.uniformer.mmcv.runner
import
get_dist_info
def np2tmp(array, temp_file_name=None):
    """Save ndarray to local numpy file.

    Args:
        array (ndarray): Ndarray to save.
        temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
            function will generate a file name with tempfile.NamedTemporaryFile
            to save ndarray. Default: None.

    Returns:
        str: The numpy file name.
    """
    target = temp_file_name
    if target is None:
        # delete=False: the file must outlive this function; the caller is
        # responsible for cleaning it up.
        target = tempfile.NamedTemporaryFile(
            suffix='.npy', delete=False).name
    np.save(target, array)
    return target
def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    efficient_test=False,
                    opacity=0.5):
    """Test with single GPU.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (utils.data.Dataloader): Pytorch data loader.
        show (bool): Whether show results during inference. Default: False.
        out_dir (str, optional): If specified, the results will be dumped into
            the directory to save output results.
        efficient_test (bool): Whether save the results as local numpy files to
            save CPU memory during evaluation. Default: False.
        opacity(float): Opacity of painted segmentation map.
            Default 0.5.
            Must be in (0, 1] range.

    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)

        if show or out_dir:
            # Recover displayable images: undo the normalization recorded in
            # each image's img_norm_cfg metadata.
            img_tensor = data['img'][0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)

            for img, img_meta in zip(imgs, img_metas):
                # Crop padding away, then resize back to the original shape.
                h, w, _ = img_meta['img_shape']
                img_show = img[:h, :w, :]

                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img_show, (ori_w, ori_h))

                if out_dir:
                    out_file = osp.join(out_dir, img_meta['ori_filename'])
                else:
                    out_file = None

                model.module.show_result(
                    img_show,
                    result,
                    palette=dataset.PALETTE,
                    show=show,
                    out_file=out_file,
                    opacity=opacity)

        if isinstance(result, list):
            if efficient_test:
                # Spill each prediction to a temp .npy file to save RAM.
                result = [np2tmp(_) for _ in result]
            results.extend(result)
        else:
            if efficient_test:
                result = np2tmp(result)
            results.append(result)

        # NOTE(review): for the non-list branch `result` may be an ndarray
        # (or a filename string after np2tmp), so len() counts elements or
        # characters rather than batch items — confirm against the model's
        # output format before relying on the progress count.
        batch_size = len(result)
        for _ in range(batch_size):
            prog_bar.update()
    return results
def multi_gpu_test(model,
                   data_loader,
                   tmpdir=None,
                   gpu_collect=False,
                   efficient_test=False):
    """Test model with multiple gpus.

    This method tests model with multiple gpus and collects the results
    under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
    it encodes results to gpu tensors and use gpu communication for results
    collection. On cpu mode it saves the results on different gpus to 'tmpdir'
    and collects them by the rank 0 worker.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (utils.data.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.
        efficient_test (bool): Whether save the results as local numpy files to
            save CPU memory during evaluation. Default: False.

    Returns:
        list: The prediction results (only meaningful on rank 0; the cpu
        collection path returns None on other ranks).
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    # Only rank 0 renders the progress bar to avoid interleaved output.
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        if isinstance(result, list):
            if efficient_test:
                # Spill each prediction to a temp .npy file to save RAM.
                result = [np2tmp(_) for _ in result]
            results.extend(result)
        else:
            if efficient_test:
                result = np2tmp(result)
            results.append(result)

        if rank == 0:
            # Advance by the global batch (local batch x world size), since
            # each rank processes its own shard in lockstep.
            batch_size = data['img'][0].size(0)
            for _ in range(batch_size * world_size):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results
def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect results with CPU.

    Each rank pickles its partial results into ``tmpdir``; rank 0 then reads
    all parts back, interleaves them into dataset order, and returns the
    merged list. Other ranks return None.

    Args:
        result_part (list): This rank's share of the predictions.
        size (int): Total number of dataset samples; the merged list is
            truncated to this length because the dataloader may pad.
        tmpdir (str, optional): Directory for the intermediate pickle files.
            When None, rank 0 creates one and broadcasts its path.

    Returns:
        list | None: Merged results on rank 0, None elsewhere.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        # The directory path is broadcast as a fixed-size uint8 tensor padded
        # with spaces, since dist.broadcast needs identical shapes on all
        # ranks; the padding is stripped with rstrip() after decoding.
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            tmpdir = tempfile.mkdtemp()
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
    # All parts must be on disk before rank 0 starts reading them.
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
            part_list.append(mmcv.load(part_file))
        # sort the results
        # zip(*parts) interleaves rank shards back into the original
        # round-robin sampler order.
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results
def collect_results_gpu(result_part, size):
    """Collect results with GPU.

    Each rank pickles its partial results into a CUDA uint8 tensor; the
    tensors are zero-padded to a common length and exchanged with
    ``all_gather``. Rank 0 unpickles every part, interleaves them into
    dataset order, and returns the merged list; other ranks return None
    implicitly.

    Args:
        result_part (list): This rank's share of the predictions.
        size (int): Total number of dataset samples; the merged list is
            truncated to this length because the dataloader may pad.

    Returns:
        list | None: Merged results on rank 0, None elsewhere.
    """
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    # all_gather requires identical tensor shapes on every rank, so each
    # part is zero-padded up to the longest one.
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            # Trim each buffer back to its true length before unpickling.
            part_list.append(
                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
        # sort the results
        # zip(*parts) interleaves rank shards back into the original
        # round-robin sampler order.
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
lavis/common/annotator/uniformer/mmseg/apis/train.py
0 → 100644
View file @
c04f261a
import
random
import
warnings
import
numpy
as
np
import
torch
from
annotator.uniformer.mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
annotator.uniformer.mmcv.runner
import
build_optimizer
,
build_runner
from
annotator.uniformer.mmseg.core
import
DistEvalHook
,
EvalHook
from
annotator.uniformer.mmseg.datasets
import
build_dataloader
,
build_dataset
from
annotator.uniformer.mmseg.utils
import
get_root_logger
def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    # Seed every RNG the training stack draws from: Python, NumPy, and
    # torch (CPU and all CUDA devices).
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def train_segmentor(model,
                    dataset,
                    cfg,
                    distributed=False,
                    validate=False,
                    timestamp=None,
                    meta=None):
    """Launch segmentor training.

    Builds dataloaders, wraps the model for (distributed) data parallelism,
    constructs the runner with its optimizer and hooks, optionally attaches
    an evaluation hook, resumes/loads a checkpoint if configured, and starts
    the training workflow.

    Args:
        model (nn.Module): The segmentor to train.
        dataset (Dataset | list[Dataset]): Training dataset(s); a single
            dataset is wrapped into a list.
        cfg (mmcv.Config): Full training config (data, optimizer, runner,
            hooks, resume/load paths, workflow).
        distributed (bool): Whether to train with DistributedDataParallel.
            Default: False.
        validate (bool): Whether to register a periodic evaluation hook.
            Default: False.
        timestamp (str, optional): Run timestamp, used to align log
            filenames.
        meta (dict, optional): Extra metadata passed to the runner.
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            drop_last=True) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    if cfg.get('runner') is None:
        # Legacy configs carry total_iters instead of a runner section;
        # synthesize an IterBasedRunner and warn.
        cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # register hooks
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

    # an ugly walkaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        # Epoch-based evaluation only makes sense for non-iter runners.
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    # Resume takes precedence over a fresh load when both are set.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
lavis/common/annotator/uniformer/mmseg/core/__init__.py
0 → 100644
View file @
c04f261a
from
.evaluation
import
*
# noqa: F401, F403
from
.seg
import
*
# noqa: F401, F403
from
.utils
import
*
# noqa: F401, F403
lavis/common/annotator/uniformer/mmseg/core/evaluation/__init__.py
0 → 100644
View file @
c04f261a
from
.class_names
import
get_classes
,
get_palette
from
.eval_hooks
import
DistEvalHook
,
EvalHook
from
.metrics
import
eval_metrics
,
mean_dice
,
mean_fscore
,
mean_iou
__all__
=
[
'EvalHook'
,
'DistEvalHook'
,
'mean_dice'
,
'mean_iou'
,
'mean_fscore'
,
'eval_metrics'
,
'get_classes'
,
'get_palette'
]
lavis/common/annotator/uniformer/mmseg/core/evaluation/class_names.py
0 → 100644
View file @
c04f261a
import
annotator.uniformer.mmcv
as
mmcv
def cityscapes_classes():
    """Cityscapes class names for external use.

    Returns:
        list[str]: The 19 evaluation class names, in label-index order.
    """
    return [
        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
        'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
        'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
        'bicycle'
    ]
def ade_classes():
    """ADE20K class names for external use.

    Returns:
        list[str]: The 150 class names, in label-index order. Note that
        'bed ' deliberately keeps its trailing space — presumably to match
        the upstream label definition; do not "fix" it.
    """
    return [
        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier',
        'crt screen', 'plate', 'monitor', 'bulletin board', 'shower',
        'radiator', 'glass', 'clock', 'flag'
    ]
def voc_classes():
    """Pascal VOC class names for external use.

    Returns:
        list[str]: 21 entries — index 0 is 'background', followed by the 20
        VOC object classes in label-index order.
    """
    return [
        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
        'tvmonitor'
    ]
def cityscapes_palette():
    """Cityscapes palette for external use.

    Returns:
        list[list[int]]: 19 RGB triplets, aligned index-for-index with
        ``cityscapes_classes()``.
    """
    return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
            [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
            [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
            [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
            [0, 0, 230], [119, 11, 32]]
def ade_palette():
    """ADE20K palette for external use.

    Returns:
        list[list[int]]: 150 RGB triplets, aligned index-for-index with
        ``ade_classes()``.
    """
    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
            [102, 255, 0], [92, 0, 255]]
def voc_palette():
    """Pascal VOC palette for external use.

    Returns:
        list[list[int]]: 21 RGB triplets, aligned index-for-index with
        ``voc_classes()``.
    """
    return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
            [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
            [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
            [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
            [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
# Canonical dataset name -> accepted aliases. get_classes()/get_palette()
# invert this table so e.g. 'ade20k' resolves to 'ade'.
dataset_aliases = {
    'cityscapes': ['cityscapes'],
    'ade': ['ade', 'ade20k'],
    'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
}
def get_classes(dataset):
    """Get class names of a dataset.

    Args:
        dataset (str): Canonical dataset name or one of its aliases
            registered in ``dataset_aliases`` (e.g. 'voc12' resolves to
            'voc').

    Returns:
        list[str]: Class names of the dataset, in label-index order.

    Raises:
        ValueError: If ``dataset`` is a string but matches no known alias.
        TypeError: If ``dataset`` is not a string.
    """
    # Invert the alias table: alias -> canonical name.
    alias2name = {}
    for name, aliases in dataset_aliases.items():
        for alias in aliases:
            alias2name[alias] = name

    # Explicit dispatch table instead of the previous
    # eval(name + '_classes()') — no string evaluation, and a typo in the
    # alias table now fails loudly with a KeyError here.
    class_getters = {
        'cityscapes': cityscapes_classes,
        'ade': ade_classes,
        'voc': voc_classes,
    }

    if mmcv.is_str(dataset):
        if dataset in alias2name:
            labels = class_getters[alias2name[dataset]]()
        else:
            raise ValueError(f'Unrecognized dataset: {dataset}')
    else:
        raise TypeError(f'dataset must a str, but got {type(dataset)}')
    return labels
def get_palette(dataset):
    """Get class palette (RGB) of a dataset.

    Args:
        dataset (str): Canonical dataset name or one of its aliases
            registered in ``dataset_aliases`` (e.g. 'voc12' resolves to
            'voc').

    Returns:
        list[list[int]]: RGB triplets aligned with the dataset's classes.

    Raises:
        ValueError: If ``dataset`` is a string but matches no known alias.
        TypeError: If ``dataset`` is not a string.
    """
    # Invert the alias table: alias -> canonical name.
    alias2name = {}
    for name, aliases in dataset_aliases.items():
        for alias in aliases:
            alias2name[alias] = name

    # Explicit dispatch table instead of the previous
    # eval(name + '_palette()') — no string evaluation, and a typo in the
    # alias table now fails loudly with a KeyError here.
    palette_getters = {
        'cityscapes': cityscapes_palette,
        'ade': ade_palette,
        'voc': voc_palette,
    }

    if mmcv.is_str(dataset):
        if dataset in alias2name:
            labels = palette_getters[alias2name[dataset]]()
        else:
            raise ValueError(f'Unrecognized dataset: {dataset}')
    else:
        raise TypeError(f'dataset must a str, but got {type(dataset)}')
    return labels
lavis/common/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py
0 → 100644
View file @
c04f261a
import
os.path
as
osp
from
annotator.uniformer.mmcv.runner
import
DistEvalHook
as
_DistEvalHook
from
annotator.uniformer.mmcv.runner
import
EvalHook
as
_EvalHook
class EvalHook(_EvalHook):
    """Single GPU EvalHook, with efficient test support.

    Args:
        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
            If set to True, it will perform by epoch. Otherwise, by iteration.
            Default: False.
        efficient_test (bool): Whether save the results as local numpy files to
            save CPU memory during evaluation. Default: False.

    Returns:
        list: The prediction results.
    """

    # Metrics for which larger is better; consumed by the base hook's
    # best-checkpoint logic.
    greater_keys = ['mIoU', 'mAcc', 'aAcc']

    def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
        super().__init__(*args, by_epoch=by_epoch, **kwargs)
        self.efficient_test = efficient_test

    def after_train_iter(self, runner):
        """After train iteration hook.

        Override default ``single_gpu_test``.
        """
        # Skip when evaluating by epoch, or off the evaluation interval.
        if self.by_epoch or not self.every_n_iters(runner, self.interval):
            return
        # Imported lazily to avoid a circular import with mmseg.apis.
        from annotator.uniformer.mmseg.apis import single_gpu_test
        runner.log_buffer.clear()
        results = single_gpu_test(
            runner.model,
            self.dataloader,
            show=False,
            efficient_test=self.efficient_test)
        self.evaluate(runner, results)

    def after_train_epoch(self, runner):
        """After train epoch hook.

        Override default ``single_gpu_test``.
        """
        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
            return
        from annotator.uniformer.mmseg.apis import single_gpu_test
        runner.log_buffer.clear()
        # NOTE(review): unlike after_train_iter, efficient_test is not
        # forwarded here — confirm whether that is intentional.
        results = single_gpu_test(runner.model, self.dataloader, show=False)
        self.evaluate(runner, results)
class DistEvalHook(_DistEvalHook):
    """Distributed EvalHook, with efficient test support.

    Args:
        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
            If set to True, it will perform by epoch. Otherwise, by iteration.
            Default: False.
        efficient_test (bool): Whether save the results as local numpy files to
            save CPU memory during evaluation. Default: False.

    Returns:
        list: The prediction results.
    """

    # Metrics for which larger is better; consumed by the base hook's
    # best-checkpoint logic.
    greater_keys = ['mIoU', 'mAcc', 'aAcc']

    def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
        super().__init__(*args, by_epoch=by_epoch, **kwargs)
        self.efficient_test = efficient_test

    def after_train_iter(self, runner):
        """After train iteration hook.

        Override default ``multi_gpu_test``.
        """
        # Skip when evaluating by epoch, or off the evaluation interval.
        if self.by_epoch or not self.every_n_iters(runner, self.interval):
            return
        # Imported lazily to avoid a circular import with mmseg.apis.
        from annotator.uniformer.mmseg.apis import multi_gpu_test
        runner.log_buffer.clear()
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
            gpu_collect=self.gpu_collect,
            efficient_test=self.efficient_test)
        # Only rank 0 receives the gathered results and evaluates them.
        if runner.rank == 0:
            print('\n')
            self.evaluate(runner, results)

    def after_train_epoch(self, runner):
        """After train epoch hook.

        Override default ``multi_gpu_test``.
        """
        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
            return
        from annotator.uniformer.mmseg.apis import multi_gpu_test
        runner.log_buffer.clear()
        # NOTE(review): unlike after_train_iter, efficient_test is not
        # forwarded here — confirm whether that is intentional.
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
            gpu_collect=self.gpu_collect)
        # Only rank 0 receives the gathered results and evaluates them.
        if runner.rank == 0:
            print('\n')
            self.evaluate(runner, results)
lavis/common/annotator/uniformer/mmseg/core/evaluation/metrics.py
0 → 100644
View file @
c04f261a
from
collections
import
OrderedDict
import
annotator.uniformer.mmcv
as
mmcv
import
numpy
as
np
import
torch
def f_score(precision, recall, beta=1):
    """calcuate the f-score value.

    Args:
        precision (float | torch.Tensor): The precision value.
        recall (float | torch.Tensor): The recall value.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        [torch.tensor]: The f-score value.
    """
    # Standard F-beta: (1 + b^2) * P * R / (b^2 * P + R).
    beta_sq = beta**2
    return (1 + beta_sq) * (precision * recall) / (
        (beta_sq * precision) + recall)
def intersect_and_union(pred_label,
                        label,
                        num_classes,
                        ignore_index,
                        label_map=dict(),
                        reduce_zero_label=False):
    """Calculate intersection and Union.

    Args:
        pred_label (ndarray | str): Prediction segmentation map
            or predict result filename.
        label (ndarray | str): Ground truth segmentation map
            or label filename.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. The parameter will
            work only when label is str. Default: dict().
        reduce_zero_label (bool): Wether ignore zero label. The parameter will
            work only when label is str. Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """
    # Accept either an in-memory array or a path to a saved result.
    if isinstance(pred_label, str):
        pred_label = torch.from_numpy(np.load(pred_label))
    else:
        pred_label = torch.from_numpy((pred_label))

    if isinstance(label, str):
        label = torch.from_numpy(
            mmcv.imread(label, flag='unchanged', backend='pillow'))
    else:
        label = torch.from_numpy(label)

    # Remap ground-truth ids when a translation table is supplied.
    if label_map is not None:
        for old_id, new_id in label_map.items():
            label[label == old_id] = new_id
    if reduce_zero_label:
        # Shift labels down by one so class 0 becomes ignore (255).
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255

    # Drop every pixel marked as ignore before counting.
    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    def _histogram(values):
        # Per-class pixel counts over [0, num_classes - 1].
        return torch.histc(
            values.float(), bins=(num_classes), min=0, max=num_classes - 1)

    intersect = pred_label[pred_label == label]
    area_intersect = _histogram(intersect)
    area_pred_label = _histogram(pred_label)
    area_label = _histogram(label)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label
def total_intersect_and_union(results,
                              gt_seg_maps,
                              num_classes,
                              ignore_index,
                              label_map=dict(),
                              reduce_zero_label=False):
    """Accumulate intersection/union histograms over a whole dataset.

    Args:
        results (list[ndarray] | list[str]): List of prediction
            segmentation maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels.
            Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.

    Returns:
        ndarray: The intersection of prediction and ground truth
            histogram on all classes.
        ndarray: The union of prediction and ground truth histogram on
            all classes.
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    # One float64 accumulator per returned histogram.
    totals = [
        torch.zeros((num_classes, ), dtype=torch.float64) for _ in range(4)
    ]
    for pred, gt in zip(results, gt_seg_maps):
        per_image = intersect_and_union(pred, gt, num_classes, ignore_index,
                                        label_map, reduce_zero_label)
        for accumulator, part in zip(totals, per_image):
            accumulator += part
    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = totals
    return total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label
def mean_iou(results,
             gt_seg_maps,
             num_classes,
             ignore_index,
             nan_to_num=None,
             label_map=dict(),
             reduce_zero_label=False):
    """Calculate Mean Intersection and Union (mIoU).

    Thin wrapper around :func:`eval_metrics` with ``metrics=['mIoU']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction
            segmentation maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be
            replaced by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels.
            Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.

    Returns:
        dict[str, float | ndarray]:
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <IoU> ndarray: Per category IoU, shape (num_classes, ).
    """
    return eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None,
              label_map=dict(),
              reduce_zero_label=False):
    """Calculate Mean Dice (mDice).

    Thin wrapper around :func:`eval_metrics` with ``metrics=['mDice']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction
            segmentation maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be
            replaced by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels.
            Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.

    Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <Dice> ndarray: Per category dice, shape (num_classes, ).
    """
    return eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
def mean_fscore(results,
                gt_seg_maps,
                num_classes,
                ignore_index,
                nan_to_num=None,
                label_map=dict(),
                reduce_zero_label=False,
                beta=1):
    """Calculate Mean F-Score (mFscore).

    Args:
        results (list[ndarray] | list[str]): List of prediction
            segmentation maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be
            replaced by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels.
            Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.
        beta (int): Determines the weight of recall in the combined
            score. Default: 1.

    Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
            <Precision> ndarray: Per category precision,
                shape (num_classes, ).
            <Recall> ndarray: Per category recall, shape (num_classes, ).
    """
    fscore_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mFscore'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label,
        beta=beta)
    return fscore_result
def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 label_map=dict(),
                 reduce_zero_label=False,
                 beta=1):
    """Calculate evaluation metrics over a dataset.

    Args:
        results (list[ndarray] | list[str]): List of prediction
            segmentation maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated, any of
            'mIoU', 'mDice' and 'mFscore'. Default: ['mIoU'].
        nan_to_num (int, optional): If specified, NaN values will be
            replaced by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels.
            Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.
        beta (int): Determines the weight of recall in the combined
            f-score (used only by 'mFscore'). Default: 1.

    Returns:
        dict[str, float | ndarray]: 'aAcc' (overall accuracy) plus the
            per-category arrays of each requested metric, each of
            shape (num_classes, ).

    Raises:
        KeyError: If ``metrics`` contains an unsupported metric name.
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError('metrics {} is not supported'.format(metrics))

    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = total_intersect_and_union(
            results, gt_seg_maps, num_classes, ignore_index, label_map,
            reduce_zero_label)
    all_acc = total_area_intersect.sum() / total_area_label.sum()
    ret_metrics = OrderedDict({'aAcc': all_acc})
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            acc = total_area_intersect / total_area_label
            ret_metrics['IoU'] = iou
            ret_metrics['Acc'] = acc
        elif metric == 'mDice':
            dice = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            acc = total_area_intersect / total_area_label
            ret_metrics['Dice'] = dice
            ret_metrics['Acc'] = acc
        elif metric == 'mFscore':
            precision = total_area_intersect / total_area_pred_label
            recall = total_area_intersect / total_area_label
            # Combine precision/recall per class into an f-score vector.
            f_value = torch.tensor(
                [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
            ret_metrics['Fscore'] = f_value
            ret_metrics['Precision'] = precision
            ret_metrics['Recall'] = recall
    # Convert all torch tensors to numpy arrays for the caller.
    ret_metrics = {
        metric: value.numpy()
        for metric, value in ret_metrics.items()
    }
    if nan_to_num is not None:
        # Classes absent from both prediction and GT divide 0/0 -> NaN;
        # optionally replace those with a user-chosen value.
        ret_metrics = OrderedDict({
            metric: np.nan_to_num(metric_value, nan=nan_to_num)
            for metric, metric_value in ret_metrics.items()
        })
    return ret_metrics
lavis/common/annotator/uniformer/mmseg/core/seg/__init__.py
0 → 100644
View file @
c04f261a
from
.builder
import
build_pixel_sampler
from
.sampler
import
BasePixelSampler
,
OHEMPixelSampler
__all__
=
[
'build_pixel_sampler'
,
'BasePixelSampler'
,
'OHEMPixelSampler'
]
lavis/common/annotator/uniformer/mmseg/core/seg/builder.py
0 → 100644
View file @
c04f261a
from
annotator.uniformer.mmcv.utils
import
Registry
,
build_from_cfg
# Registry holding all pixel-sampler classes.
PIXEL_SAMPLERS = Registry('pixel sampler')


def build_pixel_sampler(cfg, **default_args):
    """Build pixel sampler for segmentation map."""
    sampler = build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
    return sampler
lavis/common/annotator/uniformer/mmseg/core/seg/sampler/__init__.py
0 → 100644
View file @
c04f261a
from
.base_pixel_sampler
import
BasePixelSampler
from
.ohem_pixel_sampler
import
OHEMPixelSampler
__all__
=
[
'BasePixelSampler'
,
'OHEMPixelSampler'
]
lavis/common/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py
0 → 100644
View file @
c04f261a
from
abc
import
ABCMeta
,
abstractmethod
class BasePixelSampler(metaclass=ABCMeta):
    """Abstract base class that all pixel samplers must subclass."""

    def __init__(self, **kwargs):
        # No shared state; kwargs are accepted (and ignored) so that
        # subclasses may forward arbitrary configuration.
        pass

    @abstractmethod
    def sample(self, seg_logit, seg_label):
        """Placeholder for sample function."""
Prev
1
…
14
15
16
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment