Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
678f9334
Commit
678f9334
authored
Sep 23, 2018
by
Kai Chen
Browse files
refactor DataContainer and datasets
parent
b04a0157
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
151 additions
and
212 deletions
+151
-212
mmdet/datasets/__init__.py
mmdet/datasets/__init__.py
+2
-3
mmdet/datasets/coco.py
mmdet/datasets/coco.py
+40
-38
mmdet/datasets/data_engine.py
mmdet/datasets/data_engine.py
+0
-29
mmdet/datasets/loader/__init__.py
mmdet/datasets/loader/__init__.py
+7
-0
mmdet/datasets/loader/build_loader.py
mmdet/datasets/loader/build_loader.py
+39
-0
mmdet/datasets/loader/collate.py
mmdet/datasets/loader/collate.py
+17
-4
mmdet/datasets/loader/sampler.py
mmdet/datasets/loader/sampler.py
+0
-2
mmdet/datasets/transforms.py
mmdet/datasets/transforms.py
+15
-107
mmdet/datasets/utils/data_container.py
mmdet/datasets/utils/data_container.py
+7
-29
mmdet/datasets/utils/misc.py
mmdet/datasets/utils/misc.py
+24
-0
No files found.
mmdet/datasets/__init__.py
View file @
678f9334
from
.coco
import
CocoDataset
from
.coco
import
CocoDataset
from
.collate
import
*
from
.sampler
import
*
__all__
=
[
'CocoDataset'
]
from
.transforms
import
*
mmdet/datasets/coco.py
View file @
678f9334
...
@@ -7,7 +7,7 @@ from torch.utils.data import Dataset
...
@@ -7,7 +7,7 @@ from torch.utils.data import Dataset
from
.transforms
import
(
ImageTransform
,
BboxTransform
,
PolyMaskTransform
,
from
.transforms
import
(
ImageTransform
,
BboxTransform
,
PolyMaskTransform
,
Numpy2Tensor
)
Numpy2Tensor
)
from
.utils
import
show_ann
,
random_scale
from
.utils
import
to_tensor
,
show_ann
,
random_scale
from
.utils
import
DataContainer
as
DC
from
.utils
import
DataContainer
as
DC
...
@@ -71,6 +71,7 @@ def parse_ann_info(ann_info, cat2label, with_mask=True):
...
@@ -71,6 +71,7 @@ def parse_ann_info(ann_info, cat2label, with_mask=True):
class
CocoDataset
(
Dataset
):
class
CocoDataset
(
Dataset
):
def
__init__
(
self
,
def
__init__
(
self
,
ann_file
,
ann_file
,
img_prefix
,
img_prefix
,
...
@@ -227,27 +228,28 @@ class CocoDataset(Dataset):
...
@@ -227,27 +228,28 @@ class CocoDataset(Dataset):
ann
[
'mask_polys'
],
ann
[
'poly_lens'
],
ann
[
'mask_polys'
],
ann
[
'poly_lens'
],
img_info
[
'height'
],
img_info
[
'width'
],
flip
)
img_info
[
'height'
],
img_info
[
'width'
],
flip
)
ori_shape
=
(
img_info
[
'height'
],
img_info
[
'width'
])
ori_shape
=
(
img_info
[
'height'
],
img_info
[
'width'
]
,
3
)
img_meta
=
dict
(
img_meta
=
dict
(
ori_shape
=
DC
(
ori_shape
)
,
ori_shape
=
ori_shape
,
img_shape
=
DC
(
img_shape
)
,
img_shape
=
img_shape
,
scale_factor
=
DC
(
scale_factor
)
,
scale_factor
=
scale_factor
,
flip
=
DC
(
flip
)
)
flip
=
flip
)
data
=
dict
(
data
=
dict
(
img
=
DC
(
img
,
stack
=
True
),
img
=
DC
(
to_tensor
(
img
)
,
stack
=
True
),
img_meta
=
img_meta
,
img_meta
=
DC
(
img_meta
,
cpu_only
=
True
),
gt_bboxes
=
DC
(
gt_bboxes
))
gt_bboxes
=
DC
(
to_tensor
(
gt_bboxes
))
)
if
self
.
proposals
is
not
None
:
if
self
.
proposals
is
not
None
:
data
[
'proposals'
]
=
DC
(
proposals
)
data
[
'proposals'
]
=
DC
(
to_tensor
(
proposals
)
)
if
self
.
with_label
:
if
self
.
with_label
:
data
[
'gt_labels'
]
=
DC
(
gt_labels
)
data
[
'gt_labels'
]
=
DC
(
to_tensor
(
gt_labels
)
)
if
self
.
with_crowd
:
if
self
.
with_crowd
:
data
[
'gt_bboxes_ignore'
]
=
DC
(
gt_bboxes_ignore
)
data
[
'gt_bboxes_ignore'
]
=
DC
(
to_tensor
(
gt_bboxes_ignore
)
)
if
self
.
with_mask
:
if
self
.
with_mask
:
data
[
'gt_mask_polys'
]
=
DC
(
gt_mask_polys
)
data
[
'gt_masks'
]
=
dict
(
data
[
'gt_poly_lens'
]
=
DC
(
gt_poly_lens
)
polys
=
DC
(
gt_mask_polys
,
cpu_only
=
True
),
data
[
'num_polys_per_mask'
]
=
DC
(
num_polys_per_mask
)
poly_lens
=
DC
(
gt_poly_lens
,
cpu_only
=
True
),
polys_per_mask
=
DC
(
num_polys_per_mask
,
cpu_only
=
True
))
return
data
return
data
def
prepare_test_img
(
self
,
idx
):
def
prepare_test_img
(
self
,
idx
):
...
@@ -258,37 +260,37 @@ class CocoDataset(Dataset):
...
@@ -258,37 +260,37 @@ class CocoDataset(Dataset):
if
self
.
proposals
is
not
None
else
None
)
if
self
.
proposals
is
not
None
else
None
)
def
prepare_single
(
img
,
scale
,
flip
,
proposal
=
None
):
def
prepare_single
(
img
,
scale
,
flip
,
proposal
=
None
):
_img
,
_
img_shape
,
_
scale_factor
=
self
.
img_transform
(
_img
,
img_shape
,
scale_factor
=
self
.
img_transform
(
img
,
scale
,
flip
)
img
,
scale
,
flip
)
img
,
img_shape
,
scale_factor
=
self
.
numpy2tensor
(
_img
=
to_tensor
(
_img
)
_img
,
_img_shape
,
_scale_factor
)
_img_meta
=
dict
(
ori_shape
=
(
img_info
[
'height'
],
img_info
[
'width'
])
ori_shape
=
(
img_info
[
'height'
],
img_info
[
'width'
],
3
),
img_meta
=
dict
(
ori_shape
=
ori_shape
,
img_shape
=
img_shape
,
img_shape
=
img_shape
,
scale_factor
=
scale_factor
,
scale_factor
=
scale_factor
,
flip
=
flip
)
flip
=
flip
)
if
proposal
is
not
None
:
if
proposal
is
not
None
:
proposal
=
self
.
bbox_transform
(
proposal
,
_scale_factor
,
flip
)
_proposal
=
self
.
bbox_transform
(
proposal
,
scale_factor
,
flip
)
proposal
=
self
.
numpy2tensor
(
proposal
)
_proposal
=
to_tensor
(
_proposal
)
return
img
,
img_meta
,
proposal
else
:
_proposal
=
None
return
_img
,
_img_meta
,
_proposal
imgs
=
[]
imgs
=
[]
img_metas
=
[]
img_metas
=
[]
proposals
=
[]
proposals
=
[]
for
scale
in
self
.
img_scales
:
for
scale
in
self
.
img_scales
:
img
,
img_meta
,
proposal
=
prepare_single
(
img
,
scale
,
False
,
_
img
,
_
img_meta
,
_
proposal
=
prepare_single
(
proposal
)
img
,
scale
,
False
,
proposal
)
imgs
.
append
(
img
)
imgs
.
append
(
_
img
)
img_metas
.
append
(
img_meta
)
img_metas
.
append
(
DC
(
_
img_meta
,
cpu_only
=
True
)
)
proposals
.
append
(
proposal
)
proposals
.
append
(
_
proposal
)
if
self
.
flip_ratio
>
0
:
if
self
.
flip_ratio
>
0
:
img
,
img_meta
,
prop
=
prepare_single
(
img
,
scale
,
True
,
_
img
,
_
img_meta
,
_
prop
osal
=
prepare_single
(
proposal
)
img
,
scale
,
True
,
proposal
)
imgs
.
append
(
img
)
imgs
.
append
(
_
img
)
img_metas
.
append
(
img_meta
)
img_metas
.
append
(
DC
(
_
img_meta
,
cpu_only
=
True
)
)
proposals
.
append
(
prop
)
proposals
.
append
(
_
prop
osal
)
if
self
.
proposals
is
None
:
data
=
dict
(
img
=
imgs
,
img_meta
=
img_metas
)
return
imgs
,
img_metas
if
self
.
proposals
is
not
None
:
else
:
data
[
'proposals'
]
=
proposals
return
imgs
,
img_metas
,
proposals
return
data
mmdet/datasets/data_engine.py
deleted
100644 → 0
View file @
b04a0157
from
functools
import
partial
import
torch
from
.coco
import
CocoDataset
from
.collate
import
collate
from
.sampler
import
GroupSampler
,
DistributedGroupSampler
def
build_data
(
cfg
,
args
):
dataset
=
CocoDataset
(
**
cfg
)
if
args
.
dist
:
sampler
=
DistributedGroupSampler
(
dataset
,
args
.
img_per_gpu
,
args
.
world_size
,
args
.
rank
)
batch_size
=
args
.
img_per_gpu
num_workers
=
args
.
data_workers
else
:
sampler
=
GroupSampler
(
dataset
,
args
.
img_per_gpu
)
batch_size
=
args
.
world_size
*
args
.
img_per_gpu
num_workers
=
args
.
world_size
*
args
.
data_workers
loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
args
.
img_per_gpu
,
sampler
=
sampler
,
num_workers
=
num_workers
,
collate_fn
=
partial
(
collate
,
samples_per_gpu
=
args
.
img_per_gpu
),
pin_memory
=
False
)
return
loader
mmdet/datasets/loader/__init__.py
0 → 100644
View file @
678f9334
from
.build_loader
import
build_dataloader
from
.collate
import
collate
from
.sampler
import
GroupSampler
,
DistributedGroupSampler
__all__
=
[
'collate'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'build_dataloader'
]
mmdet/datasets/loader/build_loader.py
0 → 100644
View file @
678f9334
from
functools
import
partial
from
torch.utils.data
import
DataLoader
from
.collate
import
collate
from
.sampler
import
GroupSampler
,
DistributedGroupSampler
def
build_dataloader
(
dataset
,
imgs_per_gpu
,
workers_per_gpu
,
num_gpus
,
dist
=
True
,
world_size
=
1
,
rank
=
0
,
**
kwargs
):
if
dist
:
sampler
=
DistributedGroupSampler
(
dataset
,
imgs_per_gpu
,
world_size
,
rank
)
batch_size
=
imgs_per_gpu
num_workers
=
workers_per_gpu
else
:
sampler
=
GroupSampler
(
dataset
,
imgs_per_gpu
)
batch_size
=
num_gpus
*
imgs_per_gpu
num_workers
=
num_gpus
*
workers_per_gpu
if
not
kwargs
.
get
(
'shuffle'
,
True
):
sampler
=
None
data_loader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
sampler
=
sampler
,
num_workers
=
num_workers
,
collate_fn
=
partial
(
collate
,
samples_per_gpu
=
imgs_per_gpu
),
pin_memory
=
False
,
**
kwargs
)
return
data_loader
mmdet/datasets/collate.py
→
mmdet/datasets/
loader/
collate.py
View file @
678f9334
...
@@ -4,17 +4,24 @@ import torch
...
@@ -4,17 +4,24 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.utils.data.dataloader
import
default_collate
from
torch.utils.data.dataloader
import
default_collate
from
.utils
import
DataContainer
from
.
.utils
import
DataContainer
# https://github.com/pytorch/pytorch/issues/973
# https://github.com/pytorch/pytorch/issues/973
import
resource
import
resource
rlimit
=
resource
.
getrlimit
(
resource
.
RLIMIT_NOFILE
)
rlimit
=
resource
.
getrlimit
(
resource
.
RLIMIT_NOFILE
)
resource
.
setrlimit
(
resource
.
RLIMIT_NOFILE
,
(
4096
,
rlimit
[
1
]))
resource
.
setrlimit
(
resource
.
RLIMIT_NOFILE
,
(
4096
,
rlimit
[
1
]))
__all__
=
[
'collate'
]
def
collate
(
batch
,
samples_per_gpu
=
1
):
def
collate
(
batch
,
samples_per_gpu
=
1
):
"""Puts each data field into a tensor/DataContainer with outer dimension
batch size.
Extend default_collate to add support for :type:`~mmdet.DataContainer`.
There are 3 cases for data containers.
1. cpu_only = True, e.g., meta data
2. cpu_only = False, stack = True, e.g., images tensors
3. cpu_only = False, stack = False, e.g., gt bboxes
"""
if
not
isinstance
(
batch
,
collections
.
Sequence
):
if
not
isinstance
(
batch
,
collections
.
Sequence
):
raise
TypeError
(
"{} is not supported."
.
format
(
batch
.
dtype
))
raise
TypeError
(
"{} is not supported."
.
format
(
batch
.
dtype
))
...
@@ -22,7 +29,13 @@ def collate(batch, samples_per_gpu=1):
...
@@ -22,7 +29,13 @@ def collate(batch, samples_per_gpu=1):
if
isinstance
(
batch
[
0
],
DataContainer
):
if
isinstance
(
batch
[
0
],
DataContainer
):
assert
len
(
batch
)
%
samples_per_gpu
==
0
assert
len
(
batch
)
%
samples_per_gpu
==
0
stacked
=
[]
stacked
=
[]
if
batch
[
0
].
stack
:
if
batch
[
0
].
cpu_only
:
for
i
in
range
(
0
,
len
(
batch
),
samples_per_gpu
):
stacked
.
append
(
[
sample
.
data
for
sample
in
batch
[
i
:
i
+
samples_per_gpu
]])
return
DataContainer
(
stacked
,
batch
[
0
].
stack
,
batch
[
0
].
padding_value
,
cpu_only
=
True
)
elif
batch
[
0
].
stack
:
for
i
in
range
(
0
,
len
(
batch
),
samples_per_gpu
):
for
i
in
range
(
0
,
len
(
batch
),
samples_per_gpu
):
assert
isinstance
(
batch
[
i
].
data
,
torch
.
Tensor
)
assert
isinstance
(
batch
[
i
].
data
,
torch
.
Tensor
)
# TODO: handle tensors other than 3d
# TODO: handle tensors other than 3d
...
...
mmdet/datasets/sampler.py
→
mmdet/datasets/
loader/
sampler.py
View file @
678f9334
...
@@ -7,8 +7,6 @@ import numpy as np
...
@@ -7,8 +7,6 @@ import numpy as np
from
torch.distributed
import
get_world_size
,
get_rank
from
torch.distributed
import
get_world_size
,
get_rank
from
torch.utils.data.sampler
import
Sampler
from
torch.utils.data.sampler
import
Sampler
__all__
=
[
'GroupSampler'
,
'DistributedGroupSampler'
]
class
GroupSampler
(
Sampler
):
class
GroupSampler
(
Sampler
):
...
...
mmdet/datasets/transforms.py
View file @
678f9334
...
@@ -29,7 +29,7 @@ class ImageTransform(object):
...
@@ -29,7 +29,7 @@ class ImageTransform(object):
self
.
size_divisor
=
size_divisor
self
.
size_divisor
=
size_divisor
def
__call__
(
self
,
img
,
scale
,
flip
=
False
):
def
__call__
(
self
,
img
,
scale
,
flip
=
False
):
img
,
scale_factor
=
mmcv
.
imrescale
(
img
,
scale
,
True
)
img
,
scale_factor
=
mmcv
.
imrescale
(
img
,
scale
,
return_scale
=
True
)
img_shape
=
img
.
shape
img_shape
=
img
.
shape
img
=
mmcv
.
imnorm
(
img
,
self
.
mean
,
self
.
std
,
self
.
to_rgb
)
img
=
mmcv
.
imnorm
(
img
,
self
.
mean
,
self
.
std
,
self
.
to_rgb
)
if
flip
:
if
flip
:
...
@@ -39,76 +39,20 @@ class ImageTransform(object):
...
@@ -39,76 +39,20 @@ class ImageTransform(object):
img
=
img
.
transpose
(
2
,
0
,
1
)
img
=
img
.
transpose
(
2
,
0
,
1
)
return
img
,
img_shape
,
scale_factor
return
img
,
img_shape
,
scale_factor
# img, scale = cvb.resize_keep_ar(img_or_path, max_long_edge,
# max_short_edge, True)
# shape_scale = np.array(img.shape + (scale, ), dtype=np.float32)
# if flip:
# img = img[:, ::-1, :].copy()
# if self.color_order == 'RGB':
# img = cvb.bgr2rgb(img)
# img = img.astype(np.float32)
# img -= self.color_mean
# img /= self.color_std
# if self.size_divisor is None:
# padded_img = img
# else:
# pad_h = int(np.ceil(
# img.shape[0] / self.size_divisor)) * self.size_divisor
# pad_w = int(np.ceil(
# img.shape[1] / self.size_divisor)) * self.size_divisor
# padded_img = cvb.pad_img(img, (pad_h, pad_w), pad_val=0)
# padded_img = padded_img.transpose(2, 0, 1)
# return padded_img, shape_scale
class
ImageCrop
(
object
):
"""crop image patches and resize patches into fixed size
1. (read and) flip image (if needed)
2. crop image patches according to given bboxes
3. resize patches into fixed size (default 224x224)
4. normalize the image (if needed)
5. transpose to (c, h, w) (if needed)
"""
def
__init__
(
self
,
def
bbox_flip
(
bboxes
,
img_shape
):
normalize
=
True
,
"""Flip bboxes horizontally.
transpose
=
True
,
color_order
=
'RGB'
,
Args:
color_mean
=
(
0
,
0
,
0
),
bboxes(ndarray): shape (..., 4*k)
color_std
=
(
1
,
1
,
1
)):
img_shape(tuple): (height, width)
self
.
normalize
=
normalize
"""
self
.
transpose
=
transpose
assert
bboxes
.
shape
[
-
1
]
%
4
==
0
w
=
img_shape
[
1
]
assert
color_order
in
[
'RGB'
,
'BGR'
]
flipped
=
bboxes
.
copy
()
self
.
color_order
=
color_order
flipped
[...,
0
::
4
]
=
w
-
bboxes
[...,
2
::
4
]
-
1
self
.
color_mean
=
np
.
array
(
color_mean
,
dtype
=
np
.
float32
)
flipped
[...,
2
::
4
]
=
w
-
bboxes
[...,
0
::
4
]
-
1
self
.
color_std
=
np
.
array
(
color_std
,
dtype
=
np
.
float32
)
return
flipped
def
__call__
(
self
,
img_or_path
,
bboxes
,
crop_size
,
scale_ratio
=
1.0
,
flip
=
False
):
img
=
cvb
.
read_img
(
img_or_path
)
if
flip
:
img
=
img
[:,
::
-
1
,
:].
copy
()
crop_imgs
=
cvb
.
crop_img
(
img
,
bboxes
[:,
:
4
],
scale_ratio
=
scale_ratio
,
pad_fill
=
self
.
color_mean
)
processed_crop_imgs_list
=
[]
for
i
in
range
(
len
(
crop_imgs
)):
crop_img
=
crop_imgs
[
i
]
crop_img
=
cvb
.
resize
(
crop_img
,
crop_size
)
crop_img
=
crop_img
.
astype
(
np
.
float32
)
crop_img
-=
self
.
color_mean
crop_img
/=
self
.
color_std
processed_crop_imgs_list
.
append
(
crop_img
)
processed_crop_imgs
=
np
.
stack
(
processed_crop_imgs_list
,
axis
=
0
)
processed_crop_imgs
=
processed_crop_imgs
.
transpose
(
0
,
3
,
1
,
2
)
return
processed_crop_imgs
class
BboxTransform
(
object
):
class
BboxTransform
(
object
):
...
@@ -124,7 +68,7 @@ class BboxTransform(object):
...
@@ -124,7 +68,7 @@ class BboxTransform(object):
def
__call__
(
self
,
bboxes
,
img_shape
,
scale_factor
,
flip
=
False
):
def
__call__
(
self
,
bboxes
,
img_shape
,
scale_factor
,
flip
=
False
):
gt_bboxes
=
bboxes
*
scale_factor
gt_bboxes
=
bboxes
*
scale_factor
if
flip
:
if
flip
:
gt_bboxes
=
mmcv
.
bbox_flip
(
gt_bboxes
,
img_shape
)
gt_bboxes
=
bbox_flip
(
gt_bboxes
,
img_shape
)
gt_bboxes
[:,
0
::
2
]
=
np
.
clip
(
gt_bboxes
[:,
0
::
2
],
0
,
img_shape
[
1
])
gt_bboxes
[:,
0
::
2
]
=
np
.
clip
(
gt_bboxes
[:,
0
::
2
],
0
,
img_shape
[
1
])
gt_bboxes
[:,
1
::
2
]
=
np
.
clip
(
gt_bboxes
[:,
1
::
2
],
0
,
img_shape
[
0
])
gt_bboxes
[:,
1
::
2
]
=
np
.
clip
(
gt_bboxes
[:,
1
::
2
],
0
,
img_shape
[
0
])
if
self
.
max_num_gts
is
None
:
if
self
.
max_num_gts
is
None
:
...
@@ -161,42 +105,6 @@ class PolyMaskTransform(object):
...
@@ -161,42 +105,6 @@ class PolyMaskTransform(object):
return
gt_mask_polys
,
gt_poly_lens
,
num_polys_per_mask
return
gt_mask_polys
,
gt_poly_lens
,
num_polys_per_mask
class
MaskTransform
(
object
):
"""Preprocess masks
1. resize masks to expected size and stack to a single array
2. flip the masks (if needed)
3. pad the masks (if needed)
"""
def
__init__
(
self
,
max_num_gts
,
pad_size
=
None
):
self
.
max_num_gts
=
max_num_gts
self
.
pad_size
=
pad_size
def
__call__
(
self
,
masks
,
img_size
,
flip
=
False
):
max_long_edge
=
max
(
img_size
)
max_short_edge
=
min
(
img_size
)
masks
=
[
cvb
.
resize_keep_ar
(
mask
,
max_long_edge
,
max_short_edge
,
interpolation
=
cvb
.
INTER_NEAREST
)
for
mask
in
masks
]
masks
=
np
.
stack
(
masks
,
axis
=
0
)
if
flip
:
masks
=
masks
[:,
::
-
1
,
:]
if
self
.
pad_size
is
None
:
pad_h
=
masks
.
shape
[
1
]
pad_w
=
masks
.
shape
[
2
]
else
:
pad_size
=
self
.
pad_size
if
self
.
pad_size
>
0
else
max_long_edge
pad_h
=
pad_w
=
pad_size
padded_masks
=
np
.
zeros
(
(
self
.
max_num_gts
,
pad_h
,
pad_w
),
dtype
=
masks
.
dtype
)
padded_masks
[:
masks
.
shape
[
0
],
:
masks
.
shape
[
1
],
:
masks
.
shape
[
2
]]
=
masks
return
padded_masks
class
Numpy2Tensor
(
object
):
class
Numpy2Tensor
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
...
...
mmdet/datasets/utils/data_container.py
View file @
678f9334
import
functools
import
functools
from
collections
import
Sequence
import
mmcv
import
numpy
as
np
import
torch
import
torch
def
to_tensor
(
data
):
"""Convert objects of various python types to :obj:`torch.Tensor`.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`.
"""
if
isinstance
(
data
,
np
.
ndarray
):
return
torch
.
from_numpy
(
data
)
elif
isinstance
(
data
,
torch
.
Tensor
):
return
data
elif
isinstance
(
data
,
Sequence
)
and
not
mmcv
.
is_str
(
data
):
return
torch
.
tensor
(
data
)
elif
isinstance
(
data
,
int
):
return
torch
.
LongTensor
([
data
])
elif
isinstance
(
data
,
float
):
return
torch
.
FloatTensor
([
data
])
else
:
raise
TypeError
(
'type {} cannot be converted to tensor.'
.
format
(
type
(
data
)))
def
assert_tensor_type
(
func
):
def
assert_tensor_type
(
func
):
@
functools
.
wraps
(
func
)
@
functools
.
wraps
(
func
)
...
@@ -41,11 +17,9 @@ def assert_tensor_type(func):
...
@@ -41,11 +17,9 @@ def assert_tensor_type(func):
class
DataContainer
(
object
):
class
DataContainer
(
object
):
def
__init__
(
self
,
data
,
stack
=
False
,
padding_value
=
0
):
def
__init__
(
self
,
data
,
stack
=
False
,
padding_value
=
0
,
cpu_only
=
False
):
if
isinstance
(
data
,
list
):
self
.
_data
=
data
self
.
_data
=
data
self
.
_cpu_only
=
cpu_only
else
:
self
.
_data
=
to_tensor
(
data
)
self
.
_stack
=
stack
self
.
_stack
=
stack
self
.
_padding_value
=
padding_value
self
.
_padding_value
=
padding_value
...
@@ -63,6 +37,10 @@ class DataContainer(object):
...
@@ -63,6 +37,10 @@ class DataContainer(object):
else
:
else
:
return
type
(
self
.
data
)
return
type
(
self
.
data
)
@
property
def
cpu_only
(
self
):
return
self
.
_cpu_only
@
property
@
property
def
stack
(
self
):
def
stack
(
self
):
return
self
.
_stack
return
self
.
_stack
...
...
mmdet/datasets/utils/misc.py
View file @
678f9334
from
collections
import
Sequence
import
mmcv
import
mmcv
import
torch
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
numpy
as
np
import
pycocotools.mask
as
maskUtils
import
pycocotools.mask
as
maskUtils
def
to_tensor
(
data
):
"""Convert objects of various python types to :obj:`torch.Tensor`.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`.
"""
if
isinstance
(
data
,
torch
.
Tensor
):
return
data
elif
isinstance
(
data
,
np
.
ndarray
):
return
torch
.
from_numpy
(
data
)
elif
isinstance
(
data
,
Sequence
)
and
not
mmcv
.
is_str
(
data
):
return
torch
.
tensor
(
data
)
elif
isinstance
(
data
,
int
):
return
torch
.
LongTensor
([
data
])
elif
isinstance
(
data
,
float
):
return
torch
.
FloatTensor
([
data
])
else
:
raise
TypeError
(
'type {} cannot be converted to tensor.'
.
format
(
type
(
data
)))
def
random_scale
(
img_scales
,
mode
=
'range'
):
def
random_scale
(
img_scales
,
mode
=
'range'
):
"""Randomly select a scale from a list of scales or scale ranges.
"""Randomly select a scale from a list of scales or scale ranges.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment