Unverified Commit ff538ac4 authored by Yoshitomo Matsubara's avatar Yoshitomo Matsubara Committed by GitHub
Browse files

Fix repeated UserWarning and add more flexibility to reference code for segmentation tasks (#2886)

* add a README for training object detection models

* replaced np.asarray with np.array to avoid warning messages

* added data-path for flexibility

* fixed a typo
parent e987d1c0
......@@ -6,6 +6,12 @@ training and evaluation scripts to quickly bootstrap research.
All models have been trained on 8x V100 GPUs.
You must modify the following flags:
`--data-path=/path/to/dataset`
`--nproc_per_node=<number_of_gpus_available>`
## fcn_resnet50
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss
......
......@@ -12,13 +12,13 @@ import transforms as T
import utils
def get_dataset(name, image_set, transform):
def get_dataset(dir_path, name, image_set, transform):
def sbd(*args, **kwargs):
return torchvision.datasets.SBDataset(*args, mode='segmentation', **kwargs)
paths = {
"voc": ('/datasets01/VOC/060817/', torchvision.datasets.VOCSegmentation, 21),
"voc_aug": ('/datasets01/SBDD/072318/', sbd, 21),
"coco": ('/datasets01/COCO/022719/', get_coco, 21)
"voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
"voc_aug": (dir_path, sbd, 21),
"coco": (dir_path, get_coco, 21)
}
p, ds_fn, num_classes = paths[name]
......@@ -101,8 +101,8 @@ def main(args):
device = torch.device(args.device)
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True))
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False))
dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True))
dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False))
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
......@@ -186,7 +186,8 @@ def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Segmentation Training')
parser.add_argument('--dataset', default='voc', help='dataset')
parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset path')
parser.add_argument('--dataset', default='coco', help='dataset name')
parser.add_argument('--model', default='fcn_resnet101', help='model')
parser.add_argument('--aux-loss', action='store_true', help='auxiliar loss')
parser.add_argument('--device', default='cuda', help='device')
......
......@@ -78,7 +78,7 @@ class CenterCrop(object):
class ToTensor(object):
def __call__(self, image, target):
image = F.to_tensor(image)
target = torch.as_tensor(np.asarray(target), dtype=torch.int64)
target = torch.as_tensor(np.array(target), dtype=torch.int64)
return image, target
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment