Unverified Commit b05334f3 authored by Hang Zhang, committed by GitHub

V045 (#79)

* v0.4.5

* version file

* v0.4.5

* version num
parent 16363650
......@@ -6,7 +6,7 @@ Install Package
- Clone the GitHub repo::
git clone git@github.com:zhanghang1989/PyTorch-Encoding.git
git clone https://github.com/zhanghang1989/PyTorch-Encoding
- Install PyTorch Encoding (if not yet). Please follow the installation guide `Installing PyTorch Encoding <../notes/compile.html>`_.
......@@ -27,7 +27,7 @@ Test Pre-trained Model
for example ``Encnet_ResNet50_PContext``::
python test.py --dataset PContext --model-zoo Encnet_ResNet50_PContext --eval
# pixAcc: 0.7838, mIoU: 0.4958: 100%|████████████████████████| 1276/1276 [46:31<00:00, 2.19s/it]
# pixAcc: 0.7888, mIoU: 0.5056: 100%|████████████████████████| 1276/1276 [46:31<00:00, 2.19s/it]
The command for training the model can be found by clicking ``cmd`` in the table.
......@@ -37,11 +37,11 @@ Test Pre-trained Model
+----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
| Model | pixAcc | mIoU | Note | Command | Logs |
+==================================+===========+===========+===========+==============================================================================================+============+
| Encnet_ResNet50_PContext | 78.4% | 49.6% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_pcont')" class="toggleblock">cmd</a>` | ENC50PC_ |
| Encnet_ResNet50_PContext | 78.9% | 50.6% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_pcont')" class="toggleblock">cmd</a>` | ENC50PC_ |
+----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
| EncNet_ResNet101_PContext | 79.9% | 51.8% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc101_pcont')" class="toggleblock">cmd</a>` | ENC101PC_ |
| EncNet_ResNet101_PContext | 80.3% | 53.2% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc101_pcont')" class="toggleblock">cmd</a>` | ENC101PC_ |
+----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
| EncNet_ResNet50_ADE | 79.8% | 41.3% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_ade')" class="toggleblock">cmd</a>` | ENC50ADE_ |
| EncNet_ResNet50_ADE | 79.9% | 41.2% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_ade')" class="toggleblock">cmd</a>` | ENC50ADE_ |
+----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
.. _ENC50PC: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet50_pcontext.log?raw=true
......@@ -68,7 +68,7 @@ Test Pre-trained Model
</code>
<code xml:space="preserve" id="cmd_enc50_ade" style="display: none; text-align: left; white-space: pre-wrap">
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model EncNet --aux --se-loss
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ade20k --model encnetv2 --aux --se-loss
</code>
Quick Demo
......
......@@ -6,20 +6,18 @@ Install from Source
-------------------
* Install PyTorch by following the `PyTorch instructions <http://pytorch.org/>`_.
This package relies on the PyTorch master branch (newer than the stable v0.4.0 release); please follow
`the instructions <https://github.com/pytorch/pytorch#from-source>`_ to install
PyTorch from source.
* Install from source
* PIP Install::
- Clone the repo::
pip install torch-encoding
git clone https://github.com/zhanghang1989/PyTorch-Encoding && cd PyTorch-Encoding
* Install from source::
- On Linux::
python setup.py install
- On Mac OSX::
MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install
git clone https://github.com/zhanghang1989/PyTorch-Encoding && cd PyTorch-Encoding
python setup.py install
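After installing, a quick sanity check that the package imports correctly (assuming it exposes ``__version__``; this commit adds a version file)::

    import encoding
    print(encoding.__version__)  # expected: 0.4.5 after this commit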
Citations
---------
......
Extending PyTorch-Encoding
==========================
In this note we'll discuss extending the PyTorch-Encoding package, that is,
extending :mod:`torch.nn` and
:mod:`torch.autograd` with a custom CUDA backend.
Torch C and CUDA Backend
------------------------
Consider a simple example, the residual operation (over a mini-batch):
.. math::
r_{ik} = x_i - c_k
where the inputs are :math:`X=\{x_1, ..., x_N\}` and :math:`C=\{c_1, ..., c_K\}`, and the output is :math:`R=\{r_{ik}\}`.
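Before writing any CUDA, it helps to have a reference implementation to test against. Below is a minimal pure-PyTorch sketch of the same residual operation (shapes as in this note: ``X`` is B x N x D, ``C`` is K x D, ``R`` is B x N x K x D); this is only a correctness baseline, not the package's API::

    import torch

    def residual_reference(X, C):
        # r_{ik} = x_i - c_k, computed by broadcasting:
        # (B, N, 1, D) - (1, 1, K, D) -> (B, N, K, D)
        return X.unsqueeze(2) - C.unsqueeze(0).unsqueeze(0)

    X = torch.randn(2, 5, 8)   # B=2, N=5, D=8
    C = torch.randn(4, 8)      # K=4
    R = residual_reference(X, C)
    assert R.size() == (2, 5, 4, 8)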
- Add a CUDA kernel function and expose a C API in the generic file ``encoding/kernel/generic/encoding_kernel.c`` using the Torch generic-file mechanism::
__global__ void Encoding_(Residual_Forward_kernel) (
THCDeviceTensor<real, 4> R,
THCDeviceTensor<real, 3> X,
THCDeviceTensor<real, 2> D)
/*
* residual forward kernel function
*/
{
/* declarations of the variables */
int b, k, d, i, K;
/* Get the index and channels */
b = blockIdx.z;
d = blockIdx.x * blockDim.x + threadIdx.x;
i = blockIdx.y * blockDim.y + threadIdx.y;
K = R.getSize(2);
/* boundary check for output */
if (d >= X.getSize(2) || i >= X.getSize(1)) return;
/* main operation */
for(k=0; k<K; k++) {
R[b][i][k][d] = X[b][i][d].ldg() - D[k][d].ldg();
}
}
void Encoding_(Residual_Forward)(
THCState *state, THCTensor *R_, THCTensor *X_, THCTensor *D_)
/*
* residual forward
*/
{
/* Check the GPU index and tensor dims*/
THCTensor_(checkGPU)(state, 3, R_, X_, D_);
if (THCTensor_(nDimension)(state, R_) != 4 ||
THCTensor_(nDimension)(state, X_) != 3 ||
THCTensor_(nDimension)(state, D_) != 2)
THError("Encoding: incorrect input dims. \n");
/* Device tensors */
THCDeviceTensor<real, 4> R = devicetensor<4>(state, R_);
THCDeviceTensor<real, 3> X = devicetensor<3>(state, X_);
THCDeviceTensor<real, 2> D = devicetensor<2>(state, D_);
/* kernel function */
cudaStream_t stream = THCState_getCurrentStream(state);
dim3 threads(16, 16);
dim3 blocks(X.getSize(2)/16+1, X.getSize(1)/16+1,
X.getSize(0));
Encoding_(Residual_Forward_kernel)<<<blocks, threads, 0, stream>>>(R, X, D);
THCudaCheck(cudaGetLastError());
}
- Add the corresponding function header to ``encoding/kernel/generic/encoding_kernel.h``::
void Encoding_(Residual_Forward)(
THCState *state, THCTensor *R_, THCTensor *X_, THCTensor *D_);
- Add a CFFI function to ``encoding/src/generic/encoding_generic.c``, which calls the C API we just wrote::
int Encoding_(residual_forward)(THCTensor *R, THCTensor *X, THCTensor *D)
/*
* Residual operation
*/
{
Encoding_(Residual_Forward)(state, R, X, D);
/* the C function returns the number of outputs */
return 0;
}
- Add the corresponding function header to ``encoding/src/encoding_lib.h``::
int Encoding_Float_residual_forward(THCudaTensor *R, THCudaTensor *X,
THCudaTensor *D);
- Finally, call this function from Python::
class residual(Function):
def forward(self, X, C):
# X is (B x N x D), C is (K x D), R is (B x N x K x D)
B, N, D = X.size()
K = C.size(0)
with torch.cuda.device_of(X):
R = X.new(B,N,K,D)
if isinstance(X, torch.cuda.FloatTensor):
with torch.cuda.device_of(X):
encoding_lib.Encoding_Float_residual_forward(R, X, C)
elif isinstance(X, torch.cuda.DoubleTensor):
with torch.cuda.device_of(X):
encoding_lib.Encoding_Double_residual_forward(R, X, C)
else:
raise RuntimeError('Unimplemented data type!')
return R
- Note this is just an example; you also need to implement the backward function for the ``residual`` operation (a sketch of its gradients follows below).
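Since :math:`r_{ikd} = x_{id} - c_{kd}`, the backward pass reduces to two sums: :math:`\partial L/\partial x_{id} = \sum_k \partial L/\partial r_{ikd}` and :math:`\partial L/\partial c_{kd} = -\sum_{b,i} \partial L/\partial r_{ikd}`. A minimal plain-PyTorch sketch of this derivation (a CUDA version would mirror the forward kernel above)::

    def residual_backward_reference(grad_R):
        # grad_R: (B, N, K, D), the gradient w.r.t. the output R
        grad_X = grad_R.sum(dim=2)               # (B, N, D)
        grad_C = -grad_R.sum(dim=0).sum(dim=0)   # (K, D)
        return grad_X, grad_C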
......@@ -10,5 +10,6 @@ datasets = {
'pascal_aug': VOCAugSegmentation,
'pcontext': ContextSegmentation,
}
def get_segmentation_dataset(name, **kwargs):
return datasets[name.lower()](**kwargs)
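For context, this factory is what ``train.py`` and ``test.py`` call to build a dataset by name; a hedged usage sketch (the transform and sizes are placeholders, and it assumes an ``'ade20k'`` entry in the ``datasets`` table)::

    import torchvision.transforms as transforms

    transform = transforms.Compose([transforms.ToTensor()])
    trainset = get_segmentation_dataset('ade20k', split='train', mode='train',
                                        transform=transform,
                                        base_size=608, crop_size=576)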
......@@ -21,9 +21,9 @@ class ADE20KSegmentation(BaseDataset):
BASE_DIR = 'ADEChallengeData2016'
NUM_CLASS = 150
def __init__(self, root=os.path.expanduser('~/.encoding/data'), split='train',
mode=None, transform=None, target_transform=None):
mode=None, transform=None, target_transform=None, **kwargs):
super(ADE20KSegmentation, self).__init__(
root, split, mode, transform, target_transform)
root, split, mode, transform, target_transform, **kwargs)
# assert exists and prepare dataset automatically
root = os.path.join(root, self.BASE_DIR)
assert os.path.exists(root), "Please setup the dataset using" + \
......@@ -70,27 +70,37 @@ class ADE20KSegmentation(BaseDataset):
def _get_ade20k_pairs(folder, split='train'):
img_paths = []
mask_paths = []
def get_path_pairs(img_folder, mask_folder):
img_paths = []
mask_paths = []
for filename in os.listdir(img_folder):
basename, _ = os.path.splitext(filename)
if filename.endswith(".jpg"):
imgpath = os.path.join(img_folder, filename)
maskname = basename + '.png'
maskpath = os.path.join(mask_folder, maskname)
if os.path.isfile(maskpath):
img_paths.append(imgpath)
mask_paths.append(maskpath)
else:
print('cannot find the mask:', maskpath)
return img_paths, mask_paths
if split == 'train':
img_folder = os.path.join(folder, 'images/training')
mask_folder = os.path.join(folder, 'annotations/training')
img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
elif split == 'val':
img_folder = os.path.join(folder, 'images/validation')
mask_folder = os.path.join(folder, 'annotations/validation')
img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
else:
img_folder = os.path.join(folder, 'images/trainval')
mask_folder = os.path.join(folder, 'annotations/trainval')
for filename in os.listdir(img_folder):
basename, _ = os.path.splitext(filename)
if filename.endswith(".jpg"):
imgpath = os.path.join(img_folder, filename)
maskname = basename + '.png'
maskpath = os.path.join(mask_folder, maskname)
if os.path.isfile(maskpath):
img_paths.append(imgpath)
mask_paths.append(maskpath)
else:
print('cannot find the mask:', maskpath)
train_img_folder = os.path.join(folder, 'images/training')
train_mask_folder = os.path.join(folder, 'annotations/training')
val_img_folder = os.path.join(folder, 'images/validation')
val_mask_folder = os.path.join(folder, 'annotations/validation')
train_img_paths, train_mask_paths = get_path_pairs(train_img_folder, train_mask_folder)
val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder)
return train_img_paths + val_img_paths, train_mask_paths + val_mask_paths
return img_paths, mask_paths
......@@ -22,6 +22,9 @@ class BaseDataset(data.Dataset):
self.mode = mode if mode is not None else split
self.base_size = base_size
self.crop_size = crop_size
if self.mode == 'train':
print('BaseDataset: base_size {}, crop_size {}'. \
format(base_size, crop_size))
def __getitem__(self, index):
raise NotImplementedError
......
......@@ -16,8 +16,9 @@ class VOCAugSegmentation(BaseDataset):
NUM_CLASS = 21
TRAIN_BASE_DIR = 'VOCaug/dataset/'
def __init__(self, root, split='train', mode=None, transform=None,
target_transform=None):
super(VOCAugSegmentation, self).__init__(root, split, mode, transform, target_transform)
target_transform=None, **kwargs):
super(VOCAugSegmentation, self).__init__(root, split, mode, transform,
target_transform, **kwargs)
# train/val/test splits are pre-cut
_voc_root = os.path.join(root, self.TRAIN_BASE_DIR)
_mask_dir = os.path.join(_voc_root, 'cls')
......
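The ``**kwargs`` threading added across these dataset classes is what lets new options such as ``base_size`` and ``crop_size`` travel from the command line down to ``BaseDataset`` without changing every intermediate signature. A minimal sketch of the pattern (simplified, hypothetical class names)::

    class Base(object):
        def __init__(self, root, split, base_size=576, crop_size=608):
            self.base_size, self.crop_size = base_size, crop_size

    class Child(Base):
        def __init__(self, root, split='train', **kwargs):
            # unknown keyword arguments pass through untouched
            super(Child, self).__init__(root, split, **kwargs)

    ds = Child('~/.encoding/data', crop_size=480)  # reaches Base unchanged
    assert ds.crop_size == 480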
......@@ -17,8 +17,9 @@ class VOCSegmentation(BaseDataset):
NUM_CLASS = 21
BASE_DIR = 'VOCdevkit/VOC2012'
def __init__(self, root, split='train', mode=None, transform=None,
target_transform=None):
super(VOCSegmentation, self).__init__(root, split, mode, transform, target_transform)
target_transform=None, **kwargs):
super(VOCSegmentation, self).__init__(root, split, mode, transform,
target_transform, **kwargs)
_voc_root = os.path.join(self.root, self.BASE_DIR)
_mask_dir = os.path.join(_voc_root, 'SegmentationClass')
_image_dir = os.path.join(_voc_root, 'JPEGImages')
......
......@@ -18,9 +18,9 @@ class ContextSegmentation(BaseDataset):
BASE_DIR = 'VOCdevkit/VOC2010'
NUM_CLASS = 59
def __init__(self, root=os.path.expanduser('~/.encoding/data'), split='train',
mode=None, transform=None, target_transform=None):
mode=None, transform=None, target_transform=None, **kwargs):
super(ContextSegmentation, self).__init__(
root, split, mode, transform, target_transform)
root, split, mode, transform, target_transform, **kwargs)
from detail import Detail
#from detail import mask
root = os.path.join(root, self.BASE_DIR)
......@@ -78,8 +78,6 @@ class ContextSegmentation(BaseDataset):
img = self.transform(img)
return img, os.path.basename(path)
# convert mask to 60 categories
#mask = Image.fromarray(self._class_to_index(
# self.detail.getMask(img_id)))
mask = self.masks[iid]
# synchronized transform
if self.mode == 'train':
......
......@@ -24,13 +24,16 @@ __all__ = ['BaseNet', 'MultiEvalModule']
class BaseNet(nn.Module):
def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
mean=[.485, .456, .406], std=[.229, .224, .225], root='~/.encoding/models'):
base_size=576, crop_size=608, mean=[.485, .456, .406],
std=[.229, .224, .225], root='~/.encoding/models'):
super(BaseNet, self).__init__()
self.nclass = nclass
self.aux = aux
self.se_loss = se_loss
self.mean = mean
self.std = std
self.base_size = base_size
self.crop_size = crop_size
# copying modules from pretrained models
if backbone == 'resnet50':
self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated,
......@@ -70,15 +73,16 @@ class BaseNet(nn.Module):
class MultiEvalModule(DataParallel):
"""Multi-size Segmentation Eavluator"""
def __init__(self, module, nclass, device_ids=None,
base_size=520, crop_size=480, flip=True,
def __init__(self, module, nclass, device_ids=None, flip=True,
scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]):
super(MultiEvalModule, self).__init__(module, device_ids)
self.nclass = nclass
self.base_size = base_size
self.crop_size = crop_size
self.base_size = module.base_size
self.crop_size = module.crop_size
self.scales = scales
self.flip = flip
print('MultiEvalModule: base_size {}, crop_size {}'. \
format(self.base_size, self.crop_size))
def parallel_forward(self, inputs, **kwargs):
"""Multi-GPU Mult-size Evaluation
......@@ -86,7 +90,8 @@ class MultiEvalModule(DataParallel):
Args:
inputs: list of Tensors
"""
inputs = [(input.unsqueeze(0).cuda(device),) for input, device in zip(inputs, self.device_ids)]
inputs = [(input.unsqueeze(0).cuda(device),)
for input, device in zip(inputs, self.device_ids)]
replicas = self.replicate(self, self.device_ids[:len(inputs)])
kwargs = scatter(kwargs, self.device_ids[:len(inputs)], 0) if kwargs else []
if len(inputs) < len(kwargs):
......@@ -134,8 +139,8 @@ class MultiEvalModule(DataParallel):
_,_,ph,pw = pad_img.size()
assert(ph >= height and pw >= width)
# grid forward and normalize
h_grids = int(math.ceil(1.0*(ph-crop_size)/stride)) + 1
w_grids = int(math.ceil(1.0*(pw-crop_size)/stride)) + 1
h_grids = int(math.ceil(1.0 * (ph-crop_size)/stride)) + 1
w_grids = int(math.ceil(1.0 * (pw-crop_size)/stride)) + 1
with torch.cuda.device_of(image):
outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda()
count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda()
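As a worked example of the grid arithmetic (numbers are illustrative, not from the source): with padded height ``ph = 1000``, ``crop_size = 480`` and ``stride = 320``, there are ``ceil((1000 - 480) / 320) + 1 = 3`` vertical crop positions; overlapping predictions are summed into ``outputs`` and later divided by ``count_norm``::

    import math

    ph, crop_size, stride = 1000, 480, 320   # illustrative values
    h_grids = int(math.ceil(1.0 * (ph - crop_size) / stride)) + 1
    assert h_grids == 3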
......
......@@ -21,7 +21,7 @@ class EncNet(BaseNet):
norm_layer=nn.BatchNorm2d, **kwargs):
super(EncNet, self).__init__(nclass, backbone, aux, se_loss,
norm_layer=norm_layer, **kwargs)
self.head = EncHead(self.nclass, in_channels=2048, se_loss=se_loss,
self.head = EncHead(2048, self.nclass, se_loss=se_loss,
lateral=lateral, norm_layer=norm_layer,
up_kwargs=self._up_kwargs)
if aux:
......@@ -43,15 +43,15 @@ class EncNet(BaseNet):
class EncModule(nn.Module):
def __init__(self, in_channels, nclass, ncodes=32, se_loss=True, norm_layer=None):
super(EncModule, self).__init__()
norm_layer = nn.BatchNorm1d if isinstance(norm_layer, nn.BatchNorm2d) else \
encoding.nn.BatchNorm1d
#norm_layer = nn.BatchNorm1d if isinstance(norm_layer, nn.BatchNorm2d) else \
# encoding.nn.BatchNorm1d
self.se_loss = se_loss
self.encoding = nn.Sequential(
nn.Conv2d(in_channels, in_channels, 1, bias=False),
nn.BatchNorm2d(in_channels),
norm_layer(in_channels),
nn.ReLU(inplace=True),
encoding.nn.Encoding(D=in_channels, K=ncodes),
norm_layer(ncodes),
encoding.nn.BatchNorm1d(ncodes),
nn.ReLU(inplace=True),
encoding.nn.Mean(dim=1))
self.fc = nn.Sequential(
......@@ -72,7 +72,7 @@ class EncModule(nn.Module):
class EncHead(nn.Module):
def __init__(self, out_channels, in_channels, se_loss=True, lateral=True,
def __init__(self, in_channels, out_channels, se_loss=True, lateral=True,
norm_layer=None, up_kwargs=None):
super(EncHead, self).__init__()
self.se_loss = se_loss
......@@ -167,7 +167,7 @@ def get_encnet_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **
>>> model = get_encnet_resnet50_pcontext(pretrained=True)
>>> print(model)
"""
return get_encnet('pcontext', 'resnet50', pretrained, root=root, aux=False, **kwargs)
return get_encnet('pcontext', 'resnet50', pretrained, root=root, aux=True, **kwargs)
def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
......@@ -186,7 +186,7 @@ def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', *
>>> model = get_encnet_resnet101_pcontext(pretrained=True)
>>> print(model)
"""
return get_encnet('pcontext', 'resnet101', pretrained, root=root, aux=False, **kwargs)
return get_encnet('pcontext', 'resnet101', pretrained, root=root, aux=True, **kwargs)
def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
......
......@@ -8,14 +8,13 @@ from ..utils import download, check_sha1
_model_sha1 = {name: checksum for checksum, name in [
('853f2fb07aeb2927f7696e166b215609a987fd44', 'resnet50'),
#('bbba8e79b6bd131e82e2edf2ac0f119b3c6b8f87', 'resnet50'),
('5be5422ad7cb6a2e5f5a54070d0aa9affe69a9a4', 'resnet101'),
('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
('425a7b15176105be0c0ae522aefde02bdcb3b9f5', 'encnet_resnet50_pcontext'),
('abf1472fde53b7b41d7801a1f715765e1ef6f86e', 'encnet_resnet101_pcontext'),
('167f05f69df94d4066dad155d1a71dc6493747eb', 'encnet_resnet50_ade'),
('fc8c0b795abf0133700c2d4265d2f9edab7eb6cc', 'fcn_resnet50_ade'),
('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'),
('558e8904e123813f23dc0347acba85224650fe5f', 'encnet_resnet50_ade'),
('7846a2f065e90ce70d268ba8ada1a92251587734', 'encnet_resnet50_pcontext'),
('6f7c372259988bc2b6d7fc0007182e7835c31a11', 'encnet_resnet101_pcontext'),
]}
encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'
......
......@@ -39,7 +39,7 @@ def softmax_crossentropy(input, target, weight, size_average, ignore_index, redu
class SegmentationLosses(CrossEntropyLoss):
"""2D Cross Entropy Loss with Auxilary Loss"""
def __init__(self, se_loss=False, se_weight=0.1, nclass=-1,
def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
aux=False, aux_weight=0.2, weight=None,
size_average=True, ignore_index=-1):
super(SegmentationLosses, self).__init__(weight, size_average, ignore_index)
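For orientation, when both flags are on, the class combines roughly ``loss = CE(pred, target) + aux_weight * CE(aux_pred, target) + se_weight * se_term``, where the SE term scores per-image class presence (binary cross-entropy in the EncNet paper). A simplified sketch of that composition, not the class's exact code::

    import torch.nn.functional as F

    def combined_loss(pred, aux_pred, se_pred, se_target, target,
                      se_weight=0.2, aux_weight=0.2):
        loss = F.cross_entropy(pred, target, ignore_index=-1)
        loss = loss + aux_weight * F.cross_entropy(aux_pred, target,
                                                   ignore_index=-1)
        # se_target: (B, nclass) 0/1 class-presence vector built from target
        loss = loss + se_weight * F.binary_cross_entropy_with_logits(
            se_pred, se_target)
        return loss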
......
......@@ -22,7 +22,8 @@ def batch_pix_accuracy(predict, target):
target = target.cpu().numpy() + 1
pixel_labeled = np.sum(target > 0)
pixel_correct = np.sum((predict == target)*(target > 0))
assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
assert pixel_correct <= pixel_labeled, \
"Correct area should be smaller than Labeled"
return pixel_correct, pixel_labeled
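A toy example of the metric (illustrative values): both arrays are already shifted by +1, so label 0 marks ignored pixels, and only labeled pixels enter the denominator::

    import numpy as np

    predict = np.array([[1, 2], [3, 1]])  # shifted predictions
    target = np.array([[1, 2], [3, 0]])   # 0 = ignored pixel
    pixel_labeled = np.sum(target > 0)                          # 3
    pixel_correct = np.sum((predict == target) * (target > 0))  # 3
    print(pixel_correct / float(pixel_labeled))                 # pixAcc = 1.0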
......
......@@ -12,7 +12,7 @@ import torch
import torchvision
import torchvision.transforms as transforms
class Dataloder():
class Dataloader():
def __init__(self, args):
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
......
......@@ -74,7 +74,7 @@ class MINCDataloder(data.Dataset):
return len(self.images)
class Dataloder():
class Dataloader():
def __init__(self, args):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
......
......@@ -82,7 +82,6 @@ def main():
model.train()
global best_pred, errlist_train
train_loss, correct, total = 0,0,0
#adjust_learning_rate(optimizer, args, epoch, best_pred)
tbar = tqdm(train_loader, desc='\r')
for batch_idx, (data, target) in enumerate(tbar):
scheduler(optimizer, batch_idx, epoch, best_pred)
......
......@@ -25,6 +25,10 @@ class Options():
$(HOME)/data)')
parser.add_argument('--workers', type=int, default=16,
metavar='N', help='dataloader threads')
parser.add_argument('--base-size', type=int, default=608,
help='base image size')
parser.add_argument('--crop-size', type=int, default=576,
help='crop image size')
# training hyper params
parser.add_argument('--aux', action='store_true', default= False,
help='Auxiliary Loss')
......@@ -40,6 +44,7 @@ class Options():
parser.add_argument('--test-batch-size', type=int, default=None,
metavar='N', help='input batch size for \
testing (default: same as batch size)')
# optimizer params
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (default: auto)')
parser.add_argument('--lr-scheduler', type=str, default='poly',
......@@ -88,7 +93,7 @@ class Options():
'pascal_voc': 50,
'pascal_aug': 50,
'pcontext': 80,
'ade20k': 120,
'ade20k': 160,
}
args.epochs = epoches[args.dataset.lower()]
if args.batch_size is None:
......
......@@ -42,18 +42,19 @@ def test(args):
testset = get_segmentation_dataset(args.dataset, split='test', mode='test',
transform=input_transform)
# dataloader
kwargs = {'num_workers': args.workers, 'pin_memory': True} \
loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
if args.cuda else {}
test_data = data.DataLoader(testset, batch_size=args.test_batch_size,
drop_last=False, shuffle=False,
collate_fn=test_batchify_fn, **kwargs)
collate_fn=test_batchify_fn, **loader_kwargs)
# model
if args.model_zoo is not None:
model = get_model(args.model_zoo, pretrained=True)
else:
model = get_segmentation_model(args.model, dataset=args.dataset,
backbone = args.backbone, aux = args.aux,
se_loss = args.se_loss, norm_layer = BatchNorm2d)
se_loss = args.se_loss, norm_layer = BatchNorm2d,
base_size=args.base_size, crop_size=args.crop_size)
# resuming checkpoint
if args.resume is None or not os.path.isfile(args.resume):
raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
......
......@@ -34,10 +34,12 @@ class Trainer():
transform.ToTensor(),
transform.Normalize([.485, .456, .406], [.229, .224, .225])])
# dataset
trainset = get_segmentation_dataset(args.dataset, split='train',
transform=input_transform)
testset = get_segmentation_dataset(args.dataset, split='val',
transform=input_transform)
data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
'crop_size': args.crop_size}
trainset = get_segmentation_dataset(args.dataset, split='train', mode='train',
**data_kwargs)
testset = get_segmentation_dataset(args.dataset, split='val', mode ='val',
**data_kwargs)
# dataloader
kwargs = {'num_workers': args.workers, 'pin_memory': True} \
if args.cuda else {}
......@@ -49,7 +51,8 @@ class Trainer():
# model
model = get_segmentation_model(args.model, dataset=args.dataset,
backbone = args.backbone, aux = args.aux,
se_loss = args.se_loss, norm_layer = BatchNorm2d)
se_loss = args.se_loss, norm_layer = BatchNorm2d,
base_size=args.base_size, crop_size=args.crop_size)
print(model)
# optimizer using different LR
params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr},]
......