Unverified Commit ff538ac4 authored by Yoshitomo Matsubara's avatar Yoshitomo Matsubara Committed by GitHub
Browse files

Fix repeated UserWarning and add more flexibility to reference code for segmentation tasks (#2886)

* add a README for training object detection models

* replaced np.asarray with np.array to avoid warning messages

* added data-path for flexibility

* fixed a typo
parent e987d1c0
...@@ -6,6 +6,12 @@ training and evaluation scripts to quickly bootstrap research.

All models have been trained on 8x V100 GPUs.

You must modify the following flags:

`--data-path=/path/to/dataset`

`--nproc_per_node=<number_of_gpus_available>`

## fcn_resnet50
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss
```
......
...@@ -12,13 +12,13 @@ import transforms as T
import utils


def get_dataset(dir_path, name, image_set, transform):
    def sbd(*args, **kwargs):
        return torchvision.datasets.SBDataset(*args, mode='segmentation', **kwargs)

    paths = {
        "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
        "voc_aug": (dir_path, sbd, 21),
        "coco": (dir_path, get_coco, 21)
    }

    p, ds_fn, num_classes = paths[name]
...@@ -101,8 +101,8 @@ def main(args):
    device = torch.device(args.device)

    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True))
    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
...@@ -186,7 +186,8 @@ def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Segmentation Training')

    parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset path')
    parser.add_argument('--dataset', default='coco', help='dataset name')
    parser.add_argument('--model', default='fcn_resnet101', help='model')
    parser.add_argument('--aux-loss', action='store_true', help='auxiliar loss')
    parser.add_argument('--device', default='cuda', help='device')
......
...@@ -78,7 +78,7 @@ class CenterCrop(object):
class ToTensor(object):
    """Convert a (PIL image, segmentation mask) pair to tensors.

    The image is converted with ``F.to_tensor`` (float tensor scaled to
    [0, 1] — assumes ``F`` is torchvision's transforms.functional; confirm
    against the file's imports). The target mask becomes an ``int64``
    tensor of per-pixel class indices.
    """

    def __call__(self, image, target):
        # Float tensor in [0, 1], channels-first.
        image = F.to_tensor(image)
        # np.array (not np.asarray) forces a writable copy; per the commit
        # message this avoids a UserWarning being emitted on every call
        # when torch.as_tensor wraps a read-only PIL-backed array.
        target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment