Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
14a7dfb9
Unverified
Commit
14a7dfb9
authored
Oct 10, 2018
by
Kai Chen
Committed by
GitHub
Oct 10, 2018
Browse files
Merge pull request #11 from hellock/dev
Update backbone and setup scripts
parents
e8397e43
a6adf8f0
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
99 additions
and
106 deletions
+99
-106
.travis.yml
.travis.yml
+13
-0
configs/faster_rcnn_r50_fpn_1x.py
configs/faster_rcnn_r50_fpn_1x.py
+1
-1
configs/mask_rcnn_r50_fpn_1x.py
configs/mask_rcnn_r50_fpn_1x.py
+1
-1
configs/rpn_r50_fpn_1x.py
configs/rpn_r50_fpn_1x.py
+1
-1
mmdet/core/evaluation/eval_hooks.py
mmdet/core/evaluation/eval_hooks.py
+5
-1
mmdet/models/__init__.py
mmdet/models/__init__.py
+6
-4
mmdet/models/backbones/__init__.py
mmdet/models/backbones/__init__.py
+2
-2
mmdet/models/backbones/resnet.py
mmdet/models/backbones/resnet.py
+58
-86
mmdet/models/detectors/__init__.py
mmdet/models/detectors/__init__.py
+5
-1
mmdet/models/detectors/base.py
mmdet/models/detectors/base.py
+3
-6
setup.py
setup.py
+2
-1
tools/test.py
tools/test.py
+2
-2
No files found.
.travis.yml
0 → 100644
View file @
14a7dfb9
dist
:
trusty
language
:
python
install
:
-
pip install flake8
python
:
-
"
2.7"
-
"
3.5"
-
"
3.6"
script
:
-
flake8
\ No newline at end of file
configs/faster_rcnn_r50_fpn_1x.py
View file @
14a7dfb9
...
@@ -3,7 +3,7 @@ model = dict(
...
@@ -3,7 +3,7 @@ model = dict(
type
=
'FasterRCNN'
,
type
=
'FasterRCNN'
,
pretrained
=
'modelzoo://resnet50'
,
pretrained
=
'modelzoo://resnet50'
,
backbone
=
dict
(
backbone
=
dict
(
type
=
'
r
es
n
et'
,
type
=
'
R
es
N
et'
,
depth
=
50
,
depth
=
50
,
num_stages
=
4
,
num_stages
=
4
,
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
...
...
configs/mask_rcnn_r50_fpn_1x.py
View file @
14a7dfb9
...
@@ -3,7 +3,7 @@ model = dict(
...
@@ -3,7 +3,7 @@ model = dict(
type
=
'MaskRCNN'
,
type
=
'MaskRCNN'
,
pretrained
=
'modelzoo://resnet50'
,
pretrained
=
'modelzoo://resnet50'
,
backbone
=
dict
(
backbone
=
dict
(
type
=
'
r
es
n
et'
,
type
=
'
R
es
N
et'
,
depth
=
50
,
depth
=
50
,
num_stages
=
4
,
num_stages
=
4
,
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
...
...
configs/rpn_r50_fpn_1x.py
View file @
14a7dfb9
...
@@ -3,7 +3,7 @@ model = dict(
...
@@ -3,7 +3,7 @@ model = dict(
type
=
'RPN'
,
type
=
'RPN'
,
pretrained
=
'modelzoo://resnet50'
,
pretrained
=
'modelzoo://resnet50'
,
backbone
=
dict
(
backbone
=
dict
(
type
=
'
r
es
n
et'
,
type
=
'
R
es
N
et'
,
depth
=
50
,
depth
=
50
,
num_stages
=
4
,
num_stages
=
4
,
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
...
...
mmdet/core/evaluation/eval_hooks.py
View file @
14a7dfb9
...
@@ -55,6 +55,10 @@ class DistEvalHook(Hook):
...
@@ -55,6 +55,10 @@ class DistEvalHook(Hook):
shutil
.
rmtree
(
self
.
lock_dir
)
shutil
.
rmtree
(
self
.
lock_dir
)
mmcv
.
mkdir_or_exist
(
self
.
lock_dir
)
mmcv
.
mkdir_or_exist
(
self
.
lock_dir
)
def
after_run
(
self
,
runner
):
if
runner
.
rank
==
0
:
shutil
.
rmtree
(
self
.
lock_dir
)
def
after_train_epoch
(
self
,
runner
):
def
after_train_epoch
(
self
,
runner
):
if
not
self
.
every_n_epochs
(
runner
,
self
.
interval
):
if
not
self
.
every_n_epochs
(
runner
,
self
.
interval
):
return
return
...
@@ -70,7 +74,7 @@ class DistEvalHook(Hook):
...
@@ -70,7 +74,7 @@ class DistEvalHook(Hook):
# compute output
# compute output
with
torch
.
no_grad
():
with
torch
.
no_grad
():
result
=
runner
.
model
(
result
=
runner
.
model
(
**
data_gpu
,
return_loss
=
False
,
rescale
=
True
)
return_loss
=
False
,
rescale
=
True
,
**
data_gpu
)
results
[
idx
]
=
result
results
[
idx
]
=
result
batch_size
=
runner
.
world_size
batch_size
=
runner
.
world_size
...
...
mmdet/models/__init__.py
View file @
14a7dfb9
from
.detectors
import
BaseDetector
,
RPN
,
FasterRCNN
,
MaskRCNN
from
.detectors
import
(
BaseDetector
,
TwoStageDetector
,
RPN
,
FastRCNN
,
FasterRCNN
,
MaskRCNN
)
from
.builder
import
(
build_neck
,
build_rpn_head
,
build_roi_extractor
,
from
.builder
import
(
build_neck
,
build_rpn_head
,
build_roi_extractor
,
build_bbox_head
,
build_mask_head
,
build_detector
)
build_bbox_head
,
build_mask_head
,
build_detector
)
__all__
=
[
__all__
=
[
'BaseDetector'
,
'RPN'
,
'FasterRCNN'
,
'MaskRCNN'
,
'build_backbone'
,
'BaseDetector'
,
'TwoStageDetector'
,
'RPN'
,
'FastRCNN'
,
'FasterRCNN'
,
'build_neck'
,
'build_rpn_head'
,
'build_roi_extractor'
,
'build_bbox_head'
,
'MaskRCNN'
,
'build_backbone'
,
'build_neck'
,
'build_rpn_head'
,
'build_mask_head'
,
'build_detector'
'build_roi_extractor'
,
'build_bbox_head'
,
'build_mask_head'
,
'build_detector'
]
]
mmdet/models/backbones/__init__.py
View file @
14a7dfb9
from
.resnet
import
r
es
n
et
from
.resnet
import
R
es
N
et
__all__
=
[
'
r
es
n
et'
]
__all__
=
[
'
R
es
N
et'
]
mmdet/models/backbones/resnet.py
View file @
14a7dfb9
import
logging
import
logging
import
math
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
constant_init
,
kaiming_init
from
mmcv.runner
import
load_checkpoint
from
mmcv.runner
import
load_checkpoint
...
@@ -27,7 +28,8 @@ class BasicBlock(nn.Module):
...
@@ -27,7 +28,8 @@ class BasicBlock(nn.Module):
stride
=
1
,
stride
=
1
,
dilation
=
1
,
dilation
=
1
,
downsample
=
None
,
downsample
=
None
,
style
=
'pytorch'
):
style
=
'pytorch'
,
with_cp
=
False
):
super
(
BasicBlock
,
self
).
__init__
()
super
(
BasicBlock
,
self
).
__init__
()
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
dilation
)
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
dilation
)
self
.
bn1
=
nn
.
BatchNorm2d
(
planes
)
self
.
bn1
=
nn
.
BatchNorm2d
(
planes
)
...
@@ -37,6 +39,7 @@ class BasicBlock(nn.Module):
...
@@ -37,6 +39,7 @@ class BasicBlock(nn.Module):
self
.
downsample
=
downsample
self
.
downsample
=
downsample
self
.
stride
=
stride
self
.
stride
=
stride
self
.
dilation
=
dilation
self
.
dilation
=
dilation
assert
not
with_cp
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
residual
=
x
residual
=
x
...
@@ -69,7 +72,6 @@ class Bottleneck(nn.Module):
...
@@ -69,7 +72,6 @@ class Bottleneck(nn.Module):
style
=
'pytorch'
,
style
=
'pytorch'
,
with_cp
=
False
):
with_cp
=
False
):
"""Bottleneck block.
"""Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
"""
...
@@ -174,64 +176,73 @@ def make_res_layer(block,
...
@@ -174,64 +176,73 @@ def make_res_layer(block,
return
nn
.
Sequential
(
*
layers
)
return
nn
.
Sequential
(
*
layers
)
class
ResHead
(
nn
.
Module
):
class
ResNet
(
nn
.
Module
):
"""ResNet backbone.
def
__init__
(
self
,
block
,
num_blocks
,
stride
=
2
,
dilation
=
1
,
style
=
'pytorch'
):
self
.
layer4
=
make_res_layer
(
block
,
1024
,
512
,
num_blocks
,
stride
=
stride
,
dilation
=
dilation
,
style
=
style
)
def
forward
(
self
,
x
):
return
self
.
layer4
(
x
)
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
class
ResNet
(
nn
.
Module
):
arch_settings
=
{
18
:
(
BasicBlock
,
(
2
,
2
,
2
,
2
)),
34
:
(
BasicBlock
,
(
3
,
4
,
6
,
3
)),
50
:
(
Bottleneck
,
(
3
,
4
,
6
,
3
)),
101
:
(
Bottleneck
,
(
3
,
4
,
23
,
3
)),
152
:
(
Bottleneck
,
(
3
,
8
,
36
,
3
))
}
def
__init__
(
self
,
def
__init__
(
self
,
block
,
depth
,
layers
,
num_stages
=
4
,
strides
=
(
1
,
2
,
2
,
2
),
strides
=
(
1
,
2
,
2
,
2
),
dilations
=
(
1
,
1
,
1
,
1
),
dilations
=
(
1
,
1
,
1
,
1
),
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
frozen_stages
=-
1
,
style
=
'pytorch'
,
style
=
'pytorch'
,
sync_bn
=
False
,
frozen_stages
=-
1
,
with_cp
=
False
,
bn_eval
=
True
,
strict_frozen
=
False
):
bn_frozen
=
False
,
with_cp
=
False
):
super
(
ResNet
,
self
).
__init__
()
super
(
ResNet
,
self
).
__init__
()
if
not
len
(
layers
)
==
len
(
strides
)
==
len
(
dilations
):
if
depth
not
in
self
.
arch_settings
:
raise
ValueError
(
raise
KeyError
(
'invalid depth {} for resnet'
.
format
(
depth
))
'The number of layers, strides and dilations must be equal, '
assert
num_stages
>=
1
and
num_stages
<=
4
'but found have {} layers, {} strides and {} dilations'
.
format
(
block
,
stage_blocks
=
self
.
arch_settings
[
depth
]
len
(
layers
),
len
(
strides
),
len
(
dilations
)))
stage_blocks
=
stage_blocks
[:
num_stages
]
assert
max
(
out_indices
)
<
len
(
layers
)
assert
len
(
strides
)
==
len
(
dilations
)
==
num_stages
assert
max
(
out_indices
)
<
num_stages
self
.
out_indices
=
out_indices
self
.
out_indices
=
out_indices
self
.
frozen_stages
=
frozen_stages
self
.
style
=
style
self
.
style
=
style
self
.
sync_bn
=
sync_bn
self
.
frozen_stages
=
frozen_stages
self
.
bn_eval
=
bn_eval
self
.
bn_frozen
=
bn_frozen
self
.
with_cp
=
with_cp
self
.
inplanes
=
64
self
.
inplanes
=
64
self
.
conv1
=
nn
.
Conv2d
(
self
.
conv1
=
nn
.
Conv2d
(
3
,
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
3
,
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
self
.
bn1
=
nn
.
BatchNorm2d
(
64
)
self
.
bn1
=
nn
.
BatchNorm2d
(
64
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
res_layers
=
[]
for
i
,
num_blocks
in
enumerate
(
layers
):
self
.
res_layers
=
[]
for
i
,
num_blocks
in
enumerate
(
stage_blocks
):
stride
=
strides
[
i
]
stride
=
strides
[
i
]
dilation
=
dilations
[
i
]
dilation
=
dilations
[
i
]
layer_name
=
'layer{}'
.
format
(
i
+
1
)
planes
=
64
*
2
**
i
planes
=
64
*
2
**
i
res_layer
=
make_res_layer
(
res_layer
=
make_res_layer
(
block
,
block
,
...
@@ -243,12 +254,11 @@ class ResNet(nn.Module):
...
@@ -243,12 +254,11 @@ class ResNet(nn.Module):
style
=
self
.
style
,
style
=
self
.
style
,
with_cp
=
with_cp
)
with_cp
=
with_cp
)
self
.
inplanes
=
planes
*
block
.
expansion
self
.
inplanes
=
planes
*
block
.
expansion
layer_name
=
'layer{}'
.
format
(
i
+
1
)
self
.
add_module
(
layer_name
,
res_layer
)
self
.
add_module
(
layer_name
,
res_layer
)
self
.
res_layers
.
append
(
layer_name
)
self
.
res_layers
.
append
(
layer_name
)
self
.
feat_dim
=
block
.
expansion
*
64
*
2
**
(
len
(
layers
)
-
1
)
self
.
with_cp
=
with_cp
self
.
strict_frozen
=
strict_frozen
self
.
feat_dim
=
block
.
expansion
*
64
*
2
**
(
len
(
stage_blocks
)
-
1
)
def
init_weights
(
self
,
pretrained
=
None
):
def
init_weights
(
self
,
pretrained
=
None
):
if
isinstance
(
pretrained
,
str
):
if
isinstance
(
pretrained
,
str
):
...
@@ -257,11 +267,9 @@ class ResNet(nn.Module):
...
@@ -257,11 +267,9 @@ class ResNet(nn.Module):
elif
pretrained
is
None
:
elif
pretrained
is
None
:
for
m
in
self
.
modules
():
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
if
isinstance
(
m
,
nn
.
Conv2d
):
n
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
kaiming_init
(
m
)
nn
.
init
.
normal_
(
m
.
weight
,
0
,
math
.
sqrt
(
2.
/
n
))
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
nn
.
init
.
constant_
(
m
.
weight
,
1
)
constant_init
(
m
,
1
)
nn
.
init
.
constant_
(
m
.
bias
,
0
)
else
:
else
:
raise
TypeError
(
'pretrained must be a str or None'
)
raise
TypeError
(
'pretrained must be a str or None'
)
...
@@ -283,11 +291,11 @@ class ResNet(nn.Module):
...
@@ -283,11 +291,11 @@ class ResNet(nn.Module):
def
train
(
self
,
mode
=
True
):
def
train
(
self
,
mode
=
True
):
super
(
ResNet
,
self
).
train
(
mode
)
super
(
ResNet
,
self
).
train
(
mode
)
if
not
self
.
sync_bn
:
if
self
.
bn_eval
:
for
m
in
self
.
modules
():
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
m
.
eval
()
m
.
eval
()
if
self
.
strict
_frozen
:
if
self
.
bn
_frozen
:
for
params
in
m
.
parameters
():
for
params
in
m
.
parameters
():
params
.
requires_grad
=
False
params
.
requires_grad
=
False
if
mode
and
self
.
frozen_stages
>=
0
:
if
mode
and
self
.
frozen_stages
>=
0
:
...
@@ -303,39 +311,3 @@ class ResNet(nn.Module):
...
@@ -303,39 +311,3 @@ class ResNet(nn.Module):
mod
.
eval
()
mod
.
eval
()
for
param
in
mod
.
parameters
():
for
param
in
mod
.
parameters
():
param
.
requires_grad
=
False
param
.
requires_grad
=
False
resnet_cfg
=
{
18
:
(
BasicBlock
,
(
2
,
2
,
2
,
2
)),
34
:
(
BasicBlock
,
(
3
,
4
,
6
,
3
)),
50
:
(
Bottleneck
,
(
3
,
4
,
6
,
3
)),
101
:
(
Bottleneck
,
(
3
,
4
,
23
,
3
)),
152
:
(
Bottleneck
,
(
3
,
8
,
36
,
3
))
}
def
resnet
(
depth
,
num_stages
=
4
,
strides
=
(
1
,
2
,
2
,
2
),
dilations
=
(
1
,
1
,
1
,
1
),
out_indices
=
(
2
,
),
frozen_stages
=-
1
,
style
=
'pytorch'
,
sync_bn
=
False
,
with_cp
=
False
,
strict_frozen
=
False
):
"""Constructs a ResNet model.
Args:
depth (int): depth of resnet, from {18, 34, 50, 101, 152}
num_stages (int): num of resnet stages, normally 4
strides (list): strides of the first block of each stage
dilations (list): dilation of each stage
out_indices (list): output from which stages
"""
if
depth
not
in
resnet_cfg
:
raise
KeyError
(
'invalid depth {} for resnet'
.
format
(
depth
))
block
,
layers
=
resnet_cfg
[
depth
]
model
=
ResNet
(
block
,
layers
[:
num_stages
],
strides
,
dilations
,
out_indices
,
frozen_stages
,
style
,
sync_bn
,
with_cp
,
strict_frozen
)
return
model
mmdet/models/detectors/__init__.py
View file @
14a7dfb9
from
.base
import
BaseDetector
from
.base
import
BaseDetector
from
.two_stage
import
TwoStageDetector
from
.rpn
import
RPN
from
.rpn
import
RPN
from
.fast_rcnn
import
FastRCNN
from
.fast_rcnn
import
FastRCNN
from
.faster_rcnn
import
FasterRCNN
from
.faster_rcnn
import
FasterRCNN
from
.mask_rcnn
import
MaskRCNN
from
.mask_rcnn
import
MaskRCNN
__all__
=
[
'BaseDetector'
,
'RPN'
,
'FastRCNN'
,
'FasterRCNN'
,
'MaskRCNN'
]
__all__
=
[
'BaseDetector'
,
'TwoStageDetector'
,
'RPN'
,
'FastRCNN'
,
'FasterRCNN'
,
'MaskRCNN'
]
mmdet/models/detectors/base.py
View file @
14a7dfb9
...
@@ -3,7 +3,6 @@ from abc import ABCMeta, abstractmethod
...
@@ -3,7 +3,6 @@ from abc import ABCMeta, abstractmethod
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmdet.core
import
tensor2imgs
,
get_classes
from
mmdet.core
import
tensor2imgs
,
get_classes
...
@@ -34,11 +33,9 @@ class BaseDetector(nn.Module):
...
@@ -34,11 +33,9 @@ class BaseDetector(nn.Module):
pass
pass
def
extract_feats
(
self
,
imgs
):
def
extract_feats
(
self
,
imgs
):
if
isinstance
(
imgs
,
torch
.
Tensor
):
assert
isinstance
(
imgs
,
list
)
return
self
.
extract_feat
(
imgs
)
for
img
in
imgs
:
elif
isinstance
(
imgs
,
list
):
yield
self
.
extract_feat
(
img
)
for
img
in
imgs
:
yield
self
.
extract_feat
(
img
)
@
abstractmethod
@
abstractmethod
def
forward_train
(
self
,
imgs
,
img_metas
,
**
kwargs
):
def
forward_train
(
self
,
imgs
,
img_metas
,
**
kwargs
):
...
...
setup.py
View file @
14a7dfb9
...
@@ -106,6 +106,7 @@ if __name__ == '__main__':
...
@@ -106,6 +106,7 @@ if __name__ == '__main__':
setup_requires
=
[
'pytest-runner'
],
setup_requires
=
[
'pytest-runner'
],
tests_require
=
[
'pytest'
],
tests_require
=
[
'pytest'
],
install_requires
=
[
install_requires
=
[
'numpy'
,
'matplotlib'
,
'six'
,
'terminaltables'
,
'pycocotools'
'mmcv'
,
'numpy'
,
'matplotlib'
,
'six'
,
'terminaltables'
,
'pycocotools'
],
],
zip_safe
=
False
)
zip_safe
=
False
)
tools/test.py
View file @
14a7dfb9
...
@@ -17,7 +17,7 @@ def single_test(model, data_loader, show=False):
...
@@ -17,7 +17,7 @@ def single_test(model, data_loader, show=False):
prog_bar
=
mmcv
.
ProgressBar
(
len
(
data_loader
.
dataset
))
prog_bar
=
mmcv
.
ProgressBar
(
len
(
data_loader
.
dataset
))
for
i
,
data
in
enumerate
(
data_loader
):
for
i
,
data
in
enumerate
(
data_loader
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
result
=
model
(
**
data
,
return_loss
=
False
,
rescale
=
not
show
)
result
=
model
(
return_loss
=
False
,
rescale
=
not
show
,
**
data
)
results
.
append
(
result
)
results
.
append
(
result
)
if
show
:
if
show
:
...
@@ -32,7 +32,7 @@ def single_test(model, data_loader, show=False):
...
@@ -32,7 +32,7 @@ def single_test(model, data_loader, show=False):
def
_data_func
(
data
,
device_id
):
def
_data_func
(
data
,
device_id
):
data
=
scatter
(
collate
([
data
],
samples_per_gpu
=
1
),
[
device_id
])[
0
]
data
=
scatter
(
collate
([
data
],
samples_per_gpu
=
1
),
[
device_id
])[
0
]
return
dict
(
**
data
,
return_loss
=
False
,
rescale
=
True
)
return
dict
(
return_loss
=
False
,
rescale
=
True
,
**
data
)
def
parse_args
():
def
parse_args
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment