Commit 1633f310 authored by Hang Zhang

cifar exp

parent de330187
@@ -2,4 +2,4 @@
 *.swp
 *.pyc
 build/
-encoding/build/
+data/
@@ -25,8 +25,8 @@ class aggregate(Function):
     def backward(self, gradE):
         A, R = self.saved_tensors
-        gradA = A.clone()
-        gradR = R.clone()
+        gradA = A.new().resize_as_(A)
+        gradR = R.new().resize_as_(R)
         encoding_lib.Encoding_Float_aggregate_backward(gradA, gradR, gradE,
                 A, R)
         return gradA, gradR
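Why the change above helps: new().resize_as_() allocates an uninitialized buffer of the same type and shape, skipping the element copy that clone() performs; the kernel overwrites every entry anyway. A minimal sketch:

import torch

A = torch.randn(2, 3)
gradA = A.clone()              # copies A's contents, only to be overwritten
gradA = A.new().resize_as_(A)  # same type and shape, uninitialized: no copy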
@@ -36,6 +36,7 @@ class Aggregate(nn.Module):
     def forward(self, A, R):
         return aggregate()(A, R)
+
 class Encoding(nn.Module):
     def __init__(self, D, K):
         super(Encoding, self).__init__()
@@ -47,13 +48,19 @@ class Encoding(nn.Module):
         self.reset_params()

     def reset_params(self):
-        self.codewords.data.uniform_(0.0, 0.02)
-        self.scale.data.uniform_(0.0, 0.02)
+        std1 = 1./((self.K*self.D)**(1/2))
+        std2 = 1./((self.K)**(1/2))
+        self.codewords.data.uniform_(-std1, std1)
+        self.scale.data.uniform_(-std2, std2)

     def forward(self, X):
         # input X is a 4D tensor
-        assert(X.dim()==4, "Encoding Layer requries 4D featuremaps!")
         assert(X.size(1)==self.D,"Encoding Layer incompatible input channels!")
+        unpacked = False
+        if X.dim() == 3:
+            unpacked = True
+            X = X.unsqueeze(0)
         B, N, K, D = X.size(0), X.size(2)*X.size(3), self.K, self.D
         # reshape input
         X = X.view(B,D,-1).transpose(1,2)
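One caveat about the new initialization above: under Python 2 (which this commit otherwise targets, given the from __future__ imports below), 1/2 is integer division and evaluates to 0, so **(1/2) returns 1 rather than a square root. A sketch of the safe spelling, with illustrative K and D:

K, D = 16, 512
std1 = 1. / ((K * D) ** (1/2))  # Python 2: (1/2) == 0, so std1 == 1.0
std1 = 1. / ((K * D) ** 0.5)    # portable: a true square root on 2 and 3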
@@ -67,6 +74,9 @@ class Encoding(nn.Module):
         A = self.softmax(A.view(B*N,K)).view(B,N,K)
         # aggregate
         E = aggregate()(A, R)
+        if unpacked:
+            E = E.squeeze(0)
         return E

     def __repr__(self):
...
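A minimal usage sketch of the updated layer, exercising the new 3D path (assumes the encoding CUDA extension is built and a GPU is available; sizes are illustrative):

import torch
from torch.autograd import Variable
from encoding import Encoding

layer = Encoding(D=16, K=8).cuda()
x4d = Variable(torch.randn(2, 16, 4, 4).cuda())  # B x D x H x W
x3d = Variable(torch.randn(16, 4, 4).cuda())     # D x H x W, no batch dim
print(layer(x4d).size())  # B x K x D
print(layer(x3d).size())  # unsqueezed to a batch of 1, then squeezed: K x D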
@@ -89,7 +89,7 @@ __global__ void Encoding_(Aggregate_Backward_kernel) (
     sum = 0;
     for(d=0; d<D; d++) {
         //sum += L[b][k][d].ldg() * R[b][i][k][d].ldg();
-        GR[b][i][k][d] = L[b][k][d] * A[b][i][k];
+        GR[b][i][k][d] = L[b][k][d].ldg() * A[b][i][k].ldg();
         sum += L[b][k][d].ldg() * R[b][i][k][d].ldg();
     }
     GA[b][i][k] = sum;
...
@@ -12,8 +12,8 @@
 #define THC_GENERIC_FILE "generic/encoding_generic.c"
 #else

-int Encoding_(aggregate_forward)(THCudaTensor *E, THCudaTensor *A,
-    THCudaTensor *R)
+int Encoding_(aggregate_forward)(THCTensor *E, THCTensor *A,
+    THCTensor *R)
 /*
  * Aggregate operation
  */
@@ -23,8 +23,8 @@ int Encoding_(aggregate_forward)(THCTensor *E, THCTensor *A,
     return 0;
 }

-int Encoding_(aggregate_backward)(THCudaTensor *GA, THCudaTensor *GR,
-    THCudaTensor *L, THCudaTensor *A, THCudaTensor *R)
+int Encoding_(aggregate_backward)(THCTensor *GA, THCTensor *GR,
+    THCTensor *L, THCTensor *A, THCTensor *R)
 /*
  * Aggregate backward operation to A
  * G (dl/dR), L (dl/dE), A (assignments)
...
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
import torchvision
import torchvision.transforms as transforms
class Dataloder():
def __init__(self, args):
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform_test)
kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {}
trainloader = torch.utils.data.DataLoader(trainset, batch_size=
args.batch_size, shuffle=True, **kwargs)
testloader = torch.utils.data.DataLoader(testset, batch_size=
args.batch_size, shuffle=False, **kwargs)
self.trainloader = trainloader
self.testloader = testloader
def getloader(self):
return self.trainloader, self.testloader
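A minimal sketch of driving the loader above; the args namespace here is hypothetical, standing in for the parsed options (only cuda and batch_size are read):

import argparse

args = argparse.Namespace(cuda=False, batch_size=128)
train_loader, test_loader = Dataloder(args).getloader()
for images, labels in train_loader:
    print(images.size(), labels.size())  # (128, 3, 32, 32), (128,)
    break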
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from option import Options
from model.encodenet import Net
from utils import *
# global variable
best_pred = 0.0
acclist = []
def main():
    global best_pred, acclist   # restored from the checkpoint on resume below
    # init the args
    args = Options().parse()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
# init dataloader
if args.dataset == 'cifar':
from dataset.cifar import Dataloder
train_loader, test_loader = Dataloder(args).getloader()
else:
        raise ValueError('Unknown dataset!')
model = Net()
if args.cuda:
model.cuda()
if args.resume is not None:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_pred = checkpoint['best_pred']
acclist = checkpoint['acclist']
model.load_state_dict(checkpoint['state_dict'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no resume checkpoint found at '{}'".format(args.resume))
criterion = nn.CrossEntropyLoss()
    # TODO make weight_decay one of the args
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=
args.momentum, weight_decay=1e-4)
def train(epoch):
model.train()
global best_pred
train_loss, correct, total = 0,0,0
adjust_learning_rate(optimizer, epoch, best_pred, args)
for batch_idx, (data, target) in enumerate(train_loader):
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
train_loss += loss.data[0]
pred = output.data.max(1)[1]
correct += pred.eq(target.data).cpu().sum()
total += target.size(0)
progress_bar(batch_idx, len(train_loader),
'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1),
100.*correct/total, correct, total))
def test(epoch):
model.eval()
global best_pred
global acclist
test_loss, correct, total = 0,0,0
acc = 0.0
is_best = False
# for data, target in test_loader:
for batch_idx, (data, target) in enumerate(test_loader):
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
output = model(data)
test_loss += criterion(output, target).data[0]
# get the index of the max log-probability
pred = output.data.max(1)[1]
correct += pred.eq(target.data).cpu().sum()
total += target.size(0)
acc = 100.*correct/total
progress_bar(batch_idx, len(test_loader),
'Loss: %.3f | Acc: %.3f%% (%d/%d)'% (test_loss/(batch_idx+1),
acc, correct, total))
# save checkpoint
acclist += [acc]
if acc > best_pred:
best_pred = acc
is_best = True
save_checkpoint({
'epoch': epoch,
'state_dict': model.state_dict(),
'best_pred': best_pred,
'acclist':acclist,
}, args=args, is_best=is_best)
# TODO add plot curve
for epoch in range(args.start_epoch, args.epochs + 1):
train(epoch)
        # FIXME there is a bug somewhere, apparently not in this code
test(epoch)
if __name__ == "__main__":
main()
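A typical invocation, using only flags defined in option.py (the file name main.py and the checkname value are assumptions from the commit layout):

python main.py --dataset cifar --batch-size 128 --lr 0.1 --epochs 160 --checkname encoding_exp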
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
import torch.nn as nn
import model.mynn as nn2
from encoding import Encoding
class Net(nn.Module):
def __init__(self, num_blocks=[2,2,2,2], num_classes=10,
block=nn2.Bottleneck):
super(Net, self).__init__()
if block == nn2.Basicblock:
self.expansion = 1
else:
self.expansion = 4
self.inplanes = 64
num_planes = [64, 128, 256, 512]
strides = [1, 2, 2, 2]
model = []
# Conv_1
model += [nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1),
nn.BatchNorm2d(self.inplanes),
nn.ReLU(inplace=True)]
# Residual units
for i in range(4):
model += [self._residual_unit(block, num_planes[i], num_blocks[i],
strides[i])]
# Last conv layer
# TODO norm layer, instance norm?
model += [nn.BatchNorm2d(self.inplanes),
nn.ReLU(inplace=True),
Encoding(D=512*self.expansion,K=16),
nn.BatchNorm1d(16),
nn.ReLU(inplace=True),
nn2.View(-1, 512*self.expansion*16),
nn.Linear(512*self.expansion*16, num_classes)]
self.model = nn.Sequential(*model)
print(model)
def _residual_unit(self, block, planes, n_blocks, stride):
strides = [stride] + [1]*(n_blocks-1)
layers = []
for i in range(n_blocks):
layers += [block(self.inplanes, planes, strides[i])]
self.inplanes = self.expansion*planes
return nn.Sequential(*layers)
def forward(self, input):
return self.model(input)
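A shape trace of the encoding head above (a sketch assuming the default Bottleneck, so expansion = 4, and 32x32 CIFAR input):

# 3 x 32 x 32   --Conv_1-->                       64 x 32 x 32
# --residual units (strides 1,2,2,2)-->           2048 x 4 x 4
# --Encoding(D=2048, K=16)-->                     16 x 2048 per image
# --BatchNorm1d(16)-->                            normalizes over the K=16 codewords
# --View(-1, 2048*16)-->                          flat 32768-dim vector
# --Linear(2048*16, num_classes)-->               class logits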
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
import torch.nn as nn
from torch.autograd import Variable
class Basicblock(nn.Module):
def __init__(self, inplanes, planes, stride=1,
norm_layer=nn.BatchNorm2d):
        super(Basicblock, self).__init__()
        self.expansion = 1   # Basicblock keeps the channel width
        if inplanes != planes*self.expansion or stride != 1:
self.downsample = True
self.residual_layer = nn.Conv2d(inplanes, planes,
kernel_size=1, stride=stride)
else:
self.downsample = False
conv_block=[]
conv_block+=[norm_layer(inplanes),
nn.ReLU(inplace=True),
nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
padding=1),
norm_layer(planes),
nn.ReLU(inplace=True),
nn.Conv2d(planes, planes, kernel_size=3, stride=1,
padding=1),
norm_layer(planes)]
self.conv_block = nn.Sequential(*conv_block)
def forward(self, input):
if self.downsample:
residual = self.residual_layer(input)
else:
residual = input
return residual + self.conv_block(input)
class Bottleneck(nn.Module):
""" Pre-activation residual block
Identity Mapping in Deep Residual Networks
ref https://arxiv.org/abs/1603.05027
"""
def __init__(self, inplanes, planes, stride=1,norm_layer=nn.BatchNorm2d):
super(Bottleneck, self).__init__()
self.expansion = 4
        if inplanes != planes*self.expansion or stride != 1:
self.downsample = True
self.residual_layer = nn.Conv2d(inplanes, planes * self.expansion,
kernel_size=1, stride=stride)
else:
self.downsample = False
conv_block = []
conv_block += [norm_layer(inplanes),
nn.ReLU(inplace=True),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1)]
conv_block += [norm_layer(planes),
nn.ReLU(inplace=True),
nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1)]
conv_block += [norm_layer(planes),
nn.ReLU(inplace=True),
nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
stride=1)]
self.conv_block = nn.Sequential(*conv_block)
def forward(self, x):
if self.downsample:
residual = self.residual_layer(x)
else:
residual = x
return residual + self.conv_block(x)
class View(nn.Module):
def __init__(self, *args):
super(View, self).__init__()
if len(args) == 1 and isinstance(args[0], torch.Size):
self.size = args[0]
else:
self.size = torch.Size(args)
def forward(self, input):
return input.view(self.size)
class InstanceNormalization(nn.Module):
"""InstanceNormalization
Improves convergence of neural-style.
ref: https://arxiv.org/pdf/1607.08022.pdf
"""
def __init__(self, dim, eps=1e-5):
super(InstanceNormalization, self).__init__()
self.weight = nn.Parameter(torch.FloatTensor(dim))
self.bias = nn.Parameter(torch.FloatTensor(dim))
self.eps = eps
self._reset_parameters()
def _reset_parameters(self):
self.weight.data.uniform_()
self.bias.data.zero_()
def forward(self, x):
n = x.size(2) * x.size(3)
t = x.view(x.size(0), x.size(1), n)
mean = torch.mean(t, 2).unsqueeze(2).expand_as(x)
# Calculate the biased var. torch.var returns unbiased var
var = torch.var(t, 2).unsqueeze(2).expand_as(x) * ((n - 1) / float(n))
scale_broadcast = self.weight.unsqueeze(1).unsqueeze(1).unsqueeze(0)
scale_broadcast = scale_broadcast.expand_as(x)
shift_broadcast = self.bias.unsqueeze(1).unsqueeze(1).unsqueeze(0)
shift_broadcast = shift_broadcast.expand_as(x)
out = (x - mean) / torch.sqrt(var + self.eps)
out = out * scale_broadcast + shift_broadcast
return out
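A quick numerical check of the biased-variance correction used above (a sketch; torch.var divides by n-1 by default, and keepdim here is the modern spelling of the unsqueeze/expand dance in the forward):

import torch

torch.manual_seed(0)
t = torch.randn(2, 3, 16)                              # B x C x (H*W)
n = t.size(2)
biased = ((t - t.mean(2, keepdim=True)) ** 2).mean(2)  # divides by n
unbiased = t.var(2)                                    # divides by n - 1
print((unbiased * (n - 1) / float(n) - biased).abs().max())  # ~0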
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
import torch.nn as nn
import model.mynn as nn2
class Net(nn.Module):
def __init__(self, num_blocks=[2,2,2,2], num_classes=10,
block=nn2.Bottleneck):
super(Net, self).__init__()
if block == nn2.Basicblock:
self.expansion = 1
else:
self.expansion = 4
self.inplanes = 64
num_planes = [64, 128, 256, 512]
strides = [1, 2, 2, 2]
model = []
# Conv_1
model += [nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1),
nn.BatchNorm2d(self.inplanes),
nn.ReLU(inplace=True)]
# Residual units
for i in range(4):
model += [self._residual_unit(block, num_planes[i], num_blocks[i],
strides[i])]
# Last conv layer
model += [nn.BatchNorm2d(self.inplanes),
nn.ReLU(inplace=True),
nn.AvgPool2d(4),
nn2.View(-1, self.inplanes),
nn.Linear(self.inplanes, num_classes)]
self.model = nn.Sequential(*model)
print(model)
def _residual_unit(self, block, planes, n_blocks, stride):
strides = [stride] + [1]*(n_blocks-1)
layers = []
for i in range(n_blocks):
layers += [block(self.inplanes, planes, strides[i])]
self.inplanes = self.expansion*planes
return nn.Sequential(*layers)
def forward(self, input):
return self.model(input)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import argparse
import os
class Options():
def __init__(self):
# Training settings
parser = argparse.ArgumentParser(description='Deep Encoding')
parser.add_argument('--dataset', type=str, default='cifar',
help='training dataset (default: cifar)')
        parser.add_argument('--batch-size', type=int, default=128, metavar='N',
            help='input batch size for training (default: 128)')
        parser.add_argument('--test-batch-size', type=int, default=1000,
            metavar='N', help='input batch size for testing (default: 1000)')
        parser.add_argument('--epochs', type=int, default=160, metavar='N',
            help='number of epochs to train (default: 160)')
        parser.add_argument('--start_epoch', type=int, default=1, metavar='N',
            help='the epoch to start training from (default: 1)')
        parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
            help='learning rate (default: 0.1)')
        parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
            help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar=
'N',help='how many batches to wait before logging status')
parser.add_argument('--resume', type=str, default=None,
help='put the path to resuming file if needed')
parser.add_argument('--checkname', type=str, default='default',
help='set the checkpoint name')
self.parser = parser
def parse(self):
return self.parser.parse_args()
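A minimal sketch of exercising the parser off the command line (an explicit argv list is passed straight to the underlying parser, so this runs without sys.argv):

opts = Options()
args = opts.parser.parse_args(['--dataset', 'cifar', '--lr', '0.1'])
print(args.dataset, args.batch_size, args.epochs)  # cifar 128 160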
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
import shutil
import os
import sys
import time
import math
def adjust_learning_rate(optimizer, epoch, best_pred, args):
lr = args.lr * ((0.1 ** int(epoch > 80)) * (0.1 ** int(epoch > 120)))
    print('=> Epoch %i, learning rate = %.4f, previous best = %.3f%%' % (
        epoch, lr, best_pred))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)
def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'):
"""Saves checkpoint to disk"""
directory = "runs/%s/%s/"%(args.dataset, args.checkname)
if not os.path.exists(directory):
os.makedirs(directory)
filename = directory + filename
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, directory + 'model_best.pth.tar')
# taken from https://github.com/kuangliu/pytorch-cifar/blob/master/utils.py
TOTAL_BAR_LENGTH = 86.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
global last_time, begin_time
if current == 0:
begin_time = time.time() # Reset for new bar.
cur_len = int(TOTAL_BAR_LENGTH*current/total)
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
sys.stdout.write(' [')
for i in range(cur_len):
sys.stdout.write('=')
sys.stdout.write('>')
for i in range(rest_len):
sys.stdout.write('.')
sys.stdout.write(']')
cur_time = time.time()
step_time = cur_time - last_time
last_time = cur_time
tot_time = cur_time - begin_time
L = []
L.append(' Step: %s' % format_time(step_time))
L.append(' | Tot: %s' % format_time(tot_time))
if msg:
L.append(' | ' + msg)
msg = ''.join(L)
sys.stdout.write(msg)
for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
sys.stdout.write(' ')
# Go back to the center of the bar.
for i in range(term_width-int(TOTAL_BAR_LENGTH/2)):
sys.stdout.write('\b')
sys.stdout.write(' %d/%d ' % (current+1, total))
if current < total-1:
sys.stdout.write('\r')
else:
sys.stdout.write('\n')
sys.stdout.flush()
def format_time(seconds):
days = int(seconds / 3600/24)
seconds = seconds - days*3600*24
hours = int(seconds / 3600)
seconds = seconds - hours*3600
minutes = int(seconds / 60)
seconds = seconds - minutes*60
secondsf = int(seconds)
seconds = seconds - secondsf
millis = int(seconds*1000)
f = ''
i = 1
if days > 0:
f += str(days) + 'D'
i += 1
if hours > 0 and i <= 2:
f += str(hours) + 'h'
i += 1
if minutes > 0 and i <= 2:
f += str(minutes) + 'm'
i += 1
if secondsf > 0 and i <= 2:
f += str(secondsf) + 's'
i += 1
if millis > 0 and i <= 2:
f += str(millis) + 'ms'
i += 1
if f == '':
f = '0ms'
return f
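Example outputs of format_time above, following the logic as written (at most two units are ever printed):

print(format_time(0.1234))  # 123ms
print(format_time(65.5))    # 1m5s
print(format_time(3725.5))  # 1h2m
print(format_time(90000))   # 1D1h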
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
import torch.nn as nn
from torch.autograd import Variable
from encoding import Aggregate
from encoding import Encoding
from torch.autograd import gradcheck
# declare dims and variables
B, N, K, D = 1, 2, 3, 4
A = Variable(torch.randn(B,N,K).cuda(), requires_grad=True)
R = Variable(torch.randn(B,N,K,D).cuda(), requires_grad=True)
X = Variable(torch.randn(B,D,3,3).cuda(), requires_grad=True)
# check Aggregate operation
test = gradcheck(Aggregate(),(A, R), eps=1e-4, atol=1e-3)
print('Gradcheck of Aggregate() returns ', test)
# check Encoding operation
encoding = Encoding(D=D, K=K).cuda()
print(encoding)
E = encoding(X)
loss = E.view(B,-1).pow(2).sum()
loss.backward()
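A sketch of extending the same finite-difference check to the full Encoding module (hypothetical; kept at the loose eps/atol above, since the extension is Float-only):

test = gradcheck(encoding, (X,), eps=1e-4, atol=1e-3)
print('Gradcheck of Encoding() returns ', test)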