OpenDAS / vision

Commit b7615843, authored Jul 12, 2019 by flauted, committed by Francisco Massa on Jul 12, 2019. Parent: dea1afbf.

Clean det ref (#1109)

* Doc multigpu and propagate data path.
* Use raw doc because of backslash.

Showing 1 changed file with 15 additions and 6 deletions:

references/detection/train.py (+15, -6)
@@ -1,3 +1,11 @@
+r"""PyTorch Detection Training.
+
+To run in a multi-gpu environment, use the distributed launcher::
+
+    python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \
+        train.py ... --world-size $NGPU
+
+"""
 import datetime
 import os
 import time
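
For concreteness, a filled-in version of the documented launch command might look like the sketch below. The GPU count and the dataset directory are placeholders; --dataset and --data-path are the script's own flags (see the argparse hunk further down), and --world-size is the flag the docstring itself uses.

    # Hypothetical invocation on a machine with 8 GPUs; /path/to/coco is a placeholder.
    NGPU=8
    python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \
        train.py --dataset coco --data-path /path/to/coco --world-size $NGPU
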
@@ -18,10 +26,10 @@ import utils
 import transforms as T
 
 
-def get_dataset(name, image_set, transform):
+def get_dataset(name, image_set, transform, data_path):
     paths = {
-        "coco": ('/datasets01/COCO/022719/', get_coco, 91),
-        "coco_kp": ('/datasets01/COCO/022719/', get_coco_kp, 2)
+        "coco": (data_path, get_coco, 91),
+        "coco_kp": (data_path, get_coco_kp, 2)
     }
     p, ds_fn, num_classes = paths[name]
@@ -46,8 +54,8 @@ def main(args):
     # Data loading code
     print("Loading data")
 
-    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True))
-    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False))
+    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path)
+    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path)
 
     print("Creating data loaders")
     if args.distributed:
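
With the path now threaded from main() into get_dataset(), the dataset root comes from the --data-path flag instead of the hardcoded /datasets01/COCO/022719/ cluster location. A single-process run against a local copy of COCO might therefore look like the following sketch; the directory is a placeholder, and --dataset coco_kp would select the keypoint variant instead.

    # Hypothetical single-process run; /path/to/coco stands in for a local COCO root.
    python train.py --dataset coco --data-path /path/to/coco
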
@@ -125,7 +133,8 @@ def main(args):
 if __name__ == "__main__":
     import argparse
-    parser = argparse.ArgumentParser(description='PyTorch Detection Training')
+    parser = argparse.ArgumentParser(
+        description=__doc__)
 
     parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset')
     parser.add_argument('--dataset', default='coco', help='dataset')
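
The last hunk wires the module docstring into argparse, so python train.py --help now surfaces the usage notes added at the top of the file. The r prefix on that docstring is what the second bullet of the commit message refers to: in a regular triple-quoted string, a backslash at the end of a line followed by a newline is a line-continuation escape, so the backslash at the end of the documented launch command would silently vanish from the string. A minimal, stand-alone illustration (not part of the commit; the strings are made up):

    # Sketch of why the docstring needs the r prefix: compare a plain and a raw
    # triple-quoted string that both end a line with a backslash.
    plain = """a \
    b"""
    raw = r"""a \
    b"""

    print(repr(plain))  # backslash-newline is dropped as a line continuation
    print(repr(raw))    # backslash and newline are kept literally
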