"docs/source/vscode:/vscode.git/clone" did not exist on "21464e055b617b9a1c241d440f87c8efc2745e75"
Commit b7615843 authored by flauted's avatar flauted Committed by Francisco Massa
Browse files

Clean det ref (#1109)

* Doc multigpu and propagate data path.

* Use raw doc because of backslash.
parent dea1afbf
r"""PyTorch Detection Training.
To run in a multi-gpu environment, use the distributed launcher::
python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \
train.py ... --world-size $NGPU
"""
import datetime
import os
import time
......@@ -18,10 +26,10 @@ import utils
import transforms as T
def get_dataset(name, image_set, transform):
def get_dataset(name, image_set, transform, data_path):
paths = {
"coco": ('/datasets01/COCO/022719/', get_coco, 91),
"coco_kp": ('/datasets01/COCO/022719/', get_coco_kp, 2)
"coco": (data_path, get_coco, 91),
"coco_kp": (data_path, get_coco_kp, 2)
}
p, ds_fn, num_classes = paths[name]
......@@ -46,8 +54,8 @@ def main(args):
# Data loading code
print("Loading data")
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True))
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False))
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path)
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path)
print("Creating data loaders")
if args.distributed:
......@@ -125,7 +133,8 @@ def main(args):
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='PyTorch Detection Training')
parser = argparse.ArgumentParser(
description=__doc__)
parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset')
parser.add_argument('--dataset', default='coco', help='dataset')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment