Unverified Commit ce461dae authored by Hang Zhang's avatar Hang Zhang Committed by GitHub
Browse files

V1.0.0 (#156)

* v1.0
parent c2cb2aab
import encoding
import shutil
encoding.models.get_model_file('deepten_minc', root='./')
shutil.move('deepten_minc-2e22611a.pth', 'deepten_minc.pth')
......@@ -20,6 +20,8 @@ class Options():
# model params
parser.add_argument('--model', type=str, default='densenet',
help='network model type (default: densenet)')
parser.add_argument('--pretrained', action='store_true',
default=False, help='load pretrianed mode')
parser.add_argument('--nclass', type=int, default=10, metavar='N',
help='number of classes (default: 10)')
parser.add_argument('--widen', type=int, default=4, metavar='N',
......@@ -36,7 +38,9 @@ class Options():
parser.add_argument('--epochs', type=int, default=600, metavar='N',
help='number of epochs to train (default: 600)')
parser.add_argument('--start_epoch', type=int, default=1,
metavar='N', help='the epoch number to start (default: 0)')
metavar='N', help='the epoch number to start (default: 1)')
parser.add_argument('--workers', type=int, default=16,
metavar='N', help='dataloader threads')
# lr setting
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
help='learning rate (default: 0.1)')
......@@ -47,8 +51,8 @@ class Options():
# optimizer
parser.add_argument('--momentum', type=float, default=0.9,
metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=5e-4,
metavar ='M', help='SGD weight decay (default: 5e-4)')
parser.add_argument('--weight-decay', type=float, default=1e-4,
metavar ='M', help='SGD weight decay (default: 1e-4)')
# cuda, seed and logging
parser.add_argument('--no-cuda', action='store_true',
default=False, help='disables CUDA training')
......
......@@ -44,10 +44,10 @@ class Options():
help='number of epochs to train (default: auto)')
parser.add_argument('--start_epoch', type=int, default=0,
metavar='N', help='start epochs (default:0)')
parser.add_argument('--batch-size', type=int, default=None,
parser.add_argument('--batch-size', type=int, default=16,
metavar='N', help='input batch size for \
training (default: auto)')
parser.add_argument('--test-batch-size', type=int, default=None,
parser.add_argument('--test-batch-size', type=int, default=16,
metavar='N', help='input batch size for \
testing (default: same as batch size)')
# optimizer params
......@@ -77,6 +77,8 @@ class Options():
# evaluation option
parser.add_argument('--eval', action='store_true', default= False,
help='evaluating mIoU')
parser.add_argument('--test-val', action='store_true', default= False,
help='generate masks on val set')
parser.add_argument('--no-val', action='store_true', default= False,
help='skip validation during training')
# test option
......@@ -92,25 +94,21 @@ class Options():
if args.epochs is None:
epoches = {
'coco': 30,
'citys': 240,
'pascal_aug': 80,
'pascal_voc': 50,
'pascal_aug': 50,
'pcontext': 80,
'ade20k': 120,
'ade20k': 180,
'citys': 240,
}
args.epochs = epoches[args.dataset.lower()]
if args.batch_size is None:
args.batch_size = 16
if args.test_batch_size is None:
args.test_batch_size = args.batch_size
if args.lr is None:
lrs = {
'coco': 0.01,
'citys': 0.01,
'pascal_voc': 0.0001,
'coco': 0.004,
'pascal_aug': 0.001,
'pascal_voc': 0.0001,
'pcontext': 0.001,
'ade20k': 0.01,
'ade20k': 0.004,
'citys': 0.004,
}
args.lr = lrs[args.dataset.lower()] / 16 * args.batch_size
print(args)
......
......@@ -14,7 +14,7 @@ import torchvision.transforms as transform
from torch.nn.parallel.scatter_gather import gather
import encoding.utils as utils
from encoding.nn import SegmentationLosses, BatchNorm2d
from encoding.nn import SegmentationLosses, SyncBatchNorm
from encoding.parallel import DataParallelModel, DataParallelCriterion
from encoding.datasets import get_segmentation_dataset, test_batchify_fn
from encoding.models import get_model, get_segmentation_model, MultiEvalModule
......@@ -34,6 +34,9 @@ def test(args):
if args.eval:
testset = get_segmentation_dataset(args.dataset, split='val', mode='testval',
transform=input_transform)
elif args.test_val:
testset = get_segmentation_dataset(args.dataset, split='val', mode='test',
transform=input_transform)
else:
testset = get_segmentation_dataset(args.dataset, split='test', mode='test',
transform=input_transform)
......@@ -46,10 +49,12 @@ def test(args):
# model
if args.model_zoo is not None:
model = get_model(args.model_zoo, pretrained=True)
#model.base_size = args.base_size
#model.crop_size = args.crop_size
else:
model = get_segmentation_model(args.model, dataset=args.dataset,
backbone = args.backbone, aux = args.aux,
se_loss = args.se_loss, norm_layer = BatchNorm2d,
se_loss = args.se_loss, norm_layer = SyncBatchNorm,
base_size=args.base_size, crop_size=args.crop_size)
# resuming checkpoint
if args.resume is None or not os.path.isfile(args.resume):
......@@ -60,8 +65,8 @@ def test(args):
print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
print(model)
scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
scales = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
evaluator = MultiEvalModule(model, testset.num_class, scales=scales).cuda()
evaluator.eval()
metric = utils.SegmentationMetric(testset.num_class)
......
......@@ -7,6 +7,7 @@ from torch.autograd import Variable
if __name__ == "__main__":
args = Options().parse()
model = encoding.models.get_segmentation_model(args.model, dataset=args.dataset, aux=args.aux,
backbone=args.backbone,
se_loss=args.se_loss, norm_layer=torch.nn.BatchNorm2d)
print('Creating the model:')
......
......@@ -15,9 +15,9 @@ import torchvision.transforms as transform
from torch.nn.parallel.scatter_gather import gather
import encoding.utils as utils
from encoding.nn import SegmentationLosses, BatchNorm2d
from encoding.nn import SegmentationLosses, SyncBatchNorm, OHEMSegmentationLosses
from encoding.parallel import DataParallelModel, DataParallelCriterion
from encoding.datasets import get_segmentation_dataset
from encoding.datasets import get_dataset
from encoding.models import get_segmentation_model
from option import Options
......@@ -36,9 +36,9 @@ class Trainer():
# dataset
data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
'crop_size': args.crop_size}
trainset = get_segmentation_dataset(args.dataset, split=args.train_split, mode='train',
trainset = get_dataset(args.dataset, split=args.train_split, mode='train',
**data_kwargs)
testset = get_segmentation_dataset(args.dataset, split='val', mode ='val',
testset = get_dataset(args.dataset, split='val', mode ='val',
**data_kwargs)
# dataloader
kwargs = {'num_workers': args.workers, 'pin_memory': True} \
......@@ -51,7 +51,7 @@ class Trainer():
# model
model = get_segmentation_model(args.model, dataset=args.dataset,
backbone = args.backbone, aux = args.aux,
se_loss = args.se_loss, norm_layer = BatchNorm2d,
se_loss = args.se_loss, norm_layer = SyncBatchNorm,
base_size=args.base_size, crop_size=args.crop_size)
print(model)
# optimizer using different LR
......@@ -63,7 +63,8 @@ class Trainer():
optimizer = torch.optim.SGD(params_list, lr=args.lr,
momentum=args.momentum, weight_decay=args.weight_decay)
# criterions
self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,
self.criterion = SegmentationLosses(se_loss=args.se_loss,
aux=args.aux,
nclass=self.nclass,
se_weight=args.se_weight,
aux_weight=args.aux_weight)
......@@ -174,6 +175,9 @@ if __name__ == "__main__":
trainer = Trainer(args)
print('Starting Epoch:', trainer.args.start_epoch)
print('Total Epoches:', trainer.args.epochs)
if args.eval:
trainer.validation(trainer.args.start_epoch)
else:
for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
trainer.training(epoch)
if not trainer.args.no_val:
......
"""Prepare ADE20K dataset"""
import os
import shutil
import argparse
import zipfile
from encoding.utils import download, mkdir, check_sha1
_TARGET_DIR = os.path.expanduser('~/.encoding/data')
def parse_args():
parser = argparse.ArgumentParser(
description='Initialize ADE20K dataset.',
epilog='Example: python prepare_cityscapes.py',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--download-dir', default=None, help='dataset directory on disk')
args = parser.parse_args()
return args
def download_city(path, overwrite=False):
_CITY_DOWNLOAD_URLS = [
#('gtCoarse.zip', '61f23198bfff5286e0d7e316ad5c4dbbaaf4717a'),
('gtFine_trainvaltest.zip', '99f532cb1af174f5fcc4c5bc8feea8c66246ddbc'),
('leftImg8bit_trainvaltest.zip', '2c0b77ce9933cc635adda307fbba5566f5d9d404')]
download_dir = os.path.join(path, 'downloads')
mkdir(download_dir)
for filename, checksum in _CITY_DOWNLOAD_URLS:
if not check_sha1(filename, checksum):
raise UserWarning('File {} is downloaded but the content hash does not match. ' \
'The repo may be outdated or download may be incomplete. ' \
'If the "repo_url" is overridden, consider switching to ' \
'the default repo.'.format(filename))
# extract
with zipfile.ZipFile(filename,"r") as zip_ref:
zip_ref.extractall(path=path)
print("Extracted", filename)
if __name__ == '__main__':
args = parse_args()
mkdir(os.path.expanduser('~/.encoding/data'))
mkdir(os.path.expanduser('~/.encoding/data/cityscapes'))
if args.download_dir is not None:
if os.path.isdir(_TARGET_DIR):
os.remove(_TARGET_DIR)
# make symlink
os.symlink(args.download_dir, _TARGET_DIR)
else:
download_city(_TARGET_DIR, overwrite=False)
......@@ -20,21 +20,28 @@ def download_coco(path, overwrite=False):
_DOWNLOAD_URLS = [
('http://images.cocodataset.org/zips/train2017.zip',
'10ad623668ab00c62c096f0ed636d6aff41faca5'),
('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
'8551ee4bb5860311e79dace7e79cb91e432e78b3'),
('http://images.cocodataset.org/zips/val2017.zip',
'4950dc9d00dbe1c933ee0170f5797584351d2a41'),
('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
'8551ee4bb5860311e79dace7e79cb91e432e78b3'),
#('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip',
# '46cdcf715b6b4f67e980b529534e79c2edffe084'),
#('http://images.cocodataset.org/zips/test2017.zip',
# '99813c02442f3c112d491ea6f30cecf421d0e6b3'),
('https://hangzh.s3.amazonaws.com/encoding/data/coco/train_ids.pth',
'12cd266f97c8d9ea86e15a11f11bcb5faba700b6'),
('https://hangzh.s3.amazonaws.com/encoding/data/coco/val_ids.pth',
'4ce037ac33cbf3712fd93280a1c5e92dae3136bb'),
]
mkdir(path)
for url, checksum in _DOWNLOAD_URLS:
filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
# extract
if os.path.splitext(filename)[1] == '.zip':
with zipfile.ZipFile(filename) as zf:
zf.extractall(path=path)
else:
shutil.move(filename, os.path.join(path, 'annotations/'+os.path.basename(filename)))
def install_coco_api():
......
"""Prepare the ImageNet dataset"""
import os
import argparse
import tarfile
import pickle
import gzip
import subprocess
from tqdm import tqdm
from encoding.utils import check_sha1, download, mkdir
_TARGET_DIR = os.path.expanduser('~/.encoding/datasets/imagenet')
_TRAIN_TAR = 'ILSVRC2012_img_train.tar'
_TRAIN_TAR_SHA1 = '43eda4fe35c1705d6606a6a7a633bc965d194284'
_VAL_TAR = 'ILSVRC2012_img_val.tar'
_VAL_TAR_SHA1 = '5f3f73da3395154b60528b2b2a2caf2374f5f178'
def parse_args():
parser = argparse.ArgumentParser(
description='Setup the ImageNet dataset.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--download-dir', required=True,
help="The directory that contains downloaded tar files")
parser.add_argument('--target-dir', default=_TARGET_DIR,
help="The directory to store extracted images")
parser.add_argument('--checksum', action='store_true',
help="If check integrity before extracting.")
parser.add_argument('--with-rec', action='store_true',
help="If build image record files.")
parser.add_argument('--num-thread', type=int, default=1,
help="Number of threads to use when building image record file.")
args = parser.parse_args()
return args
def check_file(filename, checksum, sha1):
if not os.path.exists(filename):
raise ValueError('File not found: '+filename)
if checksum and not check_sha1(filename, sha1):
raise ValueError('Corrupted file: '+filename)
def build_rec_process(img_dir, train=False, num_thread=1):
rec_dir = os.path.abspath(os.path.join(img_dir, '../rec'))
mkdir(rec_dir)
prefix = 'train' if train else 'val'
print('Building ImageRecord file for ' + prefix + ' ...')
to_path = rec_dir
# download lst file and im2rec script
script_path = os.path.join(rec_dir, 'im2rec.py')
script_url = 'https://raw.githubusercontent.com/apache/incubator-encoding/master/tools/im2rec.py'
download(script_url, script_path)
lst_path = os.path.join(rec_dir, prefix + '.lst')
lst_url = 'http://data.encoding.io/models/imagenet/resnet/' + prefix + '.lst'
download(lst_url, lst_path)
# execution
import sys
cmd = [
sys.executable,
script_path,
rec_dir,
img_dir,
'--recursive',
'--pass-through',
'--pack-label',
'--num-thread',
str(num_thread)
]
subprocess.call(cmd)
os.remove(script_path)
os.remove(lst_path)
print('ImageRecord file for ' + prefix + ' has been built!')
def extract_train(tar_fname, target_dir, with_rec=False, num_thread=1):
os.makedirs(target_dir)
with tarfile.open(tar_fname) as tar:
print("Extracting "+tar_fname+"...")
# extract each class one-by-one
pbar = tqdm(total=len(tar.getnames()))
for class_tar in tar:
pbar.set_description('Extract '+class_tar.name)
tar.extract(class_tar, target_dir)
class_fname = os.path.join(target_dir, class_tar.name)
class_dir = os.path.splitext(class_fname)[0]
os.mkdir(class_dir)
with tarfile.open(class_fname) as f:
f.extractall(class_dir)
os.remove(class_fname)
pbar.update(1)
pbar.close()
if with_rec:
build_rec_process(target_dir, True, num_thread)
def extract_val(tar_fname, target_dir, with_rec=False, num_thread=1):
os.makedirs(target_dir)
print('Extracting ' + tar_fname)
with tarfile.open(tar_fname) as tar:
tar.extractall(target_dir)
# build rec file before images are moved into subfolders
if with_rec:
build_rec_process(target_dir, False, num_thread)
# move images to proper subfolders
val_maps_file = os.path.join(os.path.dirname(__file__), 'imagenet_val_maps.pklz')
with gzip.open(val_maps_file, 'rb') as f:
dirs, mappings = pickle.load(f)
for d in dirs:
os.makedirs(os.path.join(target_dir, d))
for m in mappings:
os.rename(os.path.join(target_dir, m[0]), os.path.join(target_dir, m[1], m[0]))
def main():
args = parse_args()
target_dir = os.path.expanduser(args.target_dir)
if os.path.exists(target_dir):
raise ValueError('Target dir ['+target_dir+'] exists. Remove it first')
download_dir = os.path.expanduser(args.download_dir)
train_tar_fname = os.path.join(download_dir, _TRAIN_TAR)
check_file(train_tar_fname, args.checksum, _TRAIN_TAR_SHA1)
val_tar_fname = os.path.join(download_dir, _VAL_TAR)
check_file(val_tar_fname, args.checksum, _VAL_TAR_SHA1)
build_rec = args.with_rec
if build_rec:
os.makedirs(os.path.join(target_dir, 'rec'))
extract_train(train_tar_fname, os.path.join(target_dir, 'train'), build_rec, args.num_thread)
extract_val(val_tar_fname, os.path.join(target_dir, 'val'), build_rec, args.num_thread)
if __name__ == '__main__':
main()
......@@ -23,7 +23,6 @@ def download_ade(path, overwrite=False):
'bf9985e9f2b064752bf6bd654d89f017c76c395a'),
('https://codalabuser.blob.core.windows.net/public/trainval_merged.json',
'169325d9f7e9047537fedca7b04de4dddf10b881'),
# You can skip these if the network is slow, the dataset will automatically generate them.
('https://hangzh.s3.amazonaws.com/encoding/data/pcontext/train.pth',
'4bfb49e8c1cefe352df876c9b5434e655c9c1d07'),
('https://hangzh.s3.amazonaws.com/encoding/data/pcontext/val.pth',
......
......@@ -18,7 +18,7 @@ import setuptools.command.install
cwd = os.path.dirname(os.path.abspath(__file__))
version = '0.5.1'
version = '1.0.1'
try:
sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
cwd=cwd).decode('ascii').strip()
......
......@@ -173,13 +173,13 @@ def test_encoding_dist_inference():
test = gradcheck(encoding.functions.encoding_dist_inference, input, eps=EPS, atol=ATOL)
print('Testing encoding_dist_inference(): {}'.format(test))
def test_sum_square():
def test_moments():
B,C,H = 2,3,4
X = Variable(torch.cuda.DoubleTensor(B,C,H).uniform_(-0.5,0.5),
requires_grad=True)
input = (X,)
test = gradcheck(encoding.functions.sum_square, input, eps=EPS, atol=ATOL)
print('Testing sum_square(): {}'.format(test))
test = gradcheck(encoding.functions.moments, input, eps=EPS, atol=ATOL)
print('Testing moments(): {}'.format(test))
def test_syncbn_func():
# generate input
......
......@@ -49,7 +49,7 @@ def testSyncBN():
def _find_bn(module):
for m in module.modules():
if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
encoding.nn.BatchNorm1d, encoding.nn.BatchNorm2d)):
encoding.nn.SyncBatchNorm)):
return m
def _syncParameters(bn1, bn2):
bn1.reset_parameters()
......@@ -70,29 +70,128 @@ def testSyncBN():
input1 = Variable(input.clone().detach(), requires_grad=True)
input2 = Variable(input.clone().detach(), requires_grad=True)
if is_train:
bn1.train()
bn2.train()
output1 = bn1(input1)
output2 = bn2(input2)
else:
bn1.eval()
bn2.eval()
with torch.no_grad():
output1 = bn1(input1)
output2 = bn2(input2)
# assert forwarding
_assert_tensor_close(input1.data, input2.data)
#_assert_tensor_close(input1.data, input2.data)
_assert_tensor_close(output1.data, output2.data)
if not is_train:
return
(output1 ** 2).sum().backward()
(output2 ** 2).sum().backward()
_assert_tensor_close(_find_bn(bn1).bias.grad.data, _find_bn(bn2).bias.grad.data)
_assert_tensor_close(_find_bn(bn1).weight.grad.data, _find_bn(bn2).weight.grad.data)
_assert_tensor_close(input1.grad.data, input2.grad.data)
_assert_tensor_close(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
_assert_tensor_close(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
#_assert_tensor_close(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
bn = torch.nn.BatchNorm2d(10).cuda().double()
sync_bn = encoding.nn.BatchNorm2d(10).double()
sync_bn = encoding.nn.SyncBatchNorm(10, inplace=True, sync=True).cuda().double()
sync_bn = torch.nn.DataParallel(sync_bn).cuda()
encoding.parallel.patch_replication_callback(sync_bn)
# check with unsync version
#_check_batchnorm_result(bn, sync_bn, torch.rand(2, 1, 2, 2).double(), True, cuda=True)
for i in range(10):
print(i)
_check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), True, cuda=True)
_check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), False, cuda=True)
#_check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), False, cuda=True)
def testABN():
class NormAct(torch.nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation="none",
slope=0.01):
super(NormAct, self).__init__(num_features, eps=eps, momentum=momentum, affine=True)
self.slope = slope
def forward(self, x):
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
else: # use exponential moving average
exponential_average_factor = self.momentum
y = torch.nn.functional.batch_norm(
x, self.running_mean, self.running_var, self.weight, self.bias,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
return torch.nn.functional.leaky_relu_(y, self.slope)
def _check_batchnorm_result(bn1, bn2, input, is_train, cuda=False):
def _find_bn(module):
for m in module.modules():
if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
encoding.nn.SyncBatchNorm)):
return m
def _syncParameters(bn1, bn2):
bn1.reset_parameters()
bn2.reset_parameters()
if bn1.affine and bn2.affine:
bn2.weight.data.copy_(bn1.weight.data)
bn2.bias.data.copy_(bn1.bias.data)
bn2.running_mean.copy_(bn1.running_mean)
bn2.running_var.copy_(bn1.running_var)
bn1.train(mode=is_train)
bn2.train(mode=is_train)
if cuda:
input = input.cuda()
# using the same values for gamma and beta
_syncParameters(_find_bn(bn1), _find_bn(bn2))
input1 = Variable(input.clone().detach(), requires_grad=True)
input2 = Variable(input.clone().detach(), requires_grad=True)
if is_train:
bn1.train()
bn2.train()
output1 = bn1(input1)
output2 = bn2(input2)
else:
bn1.eval()
bn2.eval()
with torch.no_grad():
output1 = bn1(input1)
output2 = bn2(input2)
# assert forwarding
_assert_tensor_close(output1.data, output2.data)
if not is_train:
return
loss1 = (output1 ** 2).sum()
loss2 = (output2 ** 2).sum()
loss1.backward()
loss2.backward()
_assert_tensor_close(_find_bn(bn1).bias.grad.data, _find_bn(bn2).bias.grad.data)
_assert_tensor_close(_find_bn(bn1).weight.grad.data, _find_bn(bn2).weight.grad.data)
_assert_tensor_close(input1.grad.data, input2.grad.data)
_assert_tensor_close(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
bn = NormAct(10).cuda().double()
inp_abn = encoding.nn.SyncBatchNorm(10, sync=False, activation='leaky_relu', inplace=True).cuda().double()
inp_abn = torch.nn.DataParallel(inp_abn).cuda()
# check with unsync version
for i in range(10):
print(i)
_check_batchnorm_result(bn, inp_abn, torch.rand(16, 10, 16, 16).double(), True, cuda=True)
#_check_batchnorm_result(bn, inp_abn, torch.rand(16, 10, 16, 16).double(), False, cuda=True)
def test_Atten_Module():
B, C, H, W = 8, 24, 10, 10
X = Variable(torch.cuda.DoubleTensor(B,C,H,W).uniform_(-0.5,0.5),
requires_grad=True)
layer1 = encoding.nn.MultiHeadAttention(4, 24, 24, 24).double().cuda()
Y = layer1(X)
if __name__ == '__main__':
import nose
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment