Commit d0ea98be authored by VoVAllen, committed by Minjie Wang

[Model] Capsule (#95)

* add capsule example

* clean code

* better naming

* better naming

* Clean code based on the PyTorch MNIST example

* Clean code

* Add README
parent 9e9a9488
# DGLDigitCapsule.py
import dgl
import torch
from torch import nn
from torch.nn import functional as F


class DGLDigitCapsuleLayer(nn.Module):
    def __init__(self, input_capsule_dim=8, input_capsule_num=1152,
                 output_capsule_num=10, output_capsule_dim=16,
                 num_routing=3, device='cpu'):
        super(DGLDigitCapsuleLayer, self).__init__()
        self.device = device
        self.input_capsule_dim = input_capsule_dim
        self.input_capsule_num = input_capsule_num
        self.output_capsule_dim = output_capsule_dim
        self.output_capsule_num = output_capsule_num
        self.num_routing = num_routing
        # One transformation matrix W_ij per (input capsule, output capsule) pair.
        self.weight = nn.Parameter(
            torch.randn(input_capsule_num, output_capsule_num,
                        output_capsule_dim, input_capsule_dim))
        self.g, self.input_nodes, self.output_nodes = self.construct_graph()

    def construct_graph(self):
        # Complete bipartite graph: every input capsule connects to every
        # output capsule; dynamic routing runs over these edges.
        g = dgl.DGLGraph()
        g.add_nodes(self.input_capsule_num + self.output_capsule_num)
        input_nodes = list(range(self.input_capsule_num))
        output_nodes = list(range(self.input_capsule_num,
                                  self.input_capsule_num + self.output_capsule_num))
        u, v = [], []
        for i in input_nodes:
            for j in output_nodes:
                u.append(i)
                v.append(j)
        g.add_edges(u, v)
        return g, input_nodes, output_nodes

    def forward(self, x):
        self.batch_size = x.size(0)
        x = x.transpose(1, 2)
        # Replicate the input once per output capsule, then apply W_ij to
        # obtain the prediction vectors u_hat.
        x = torch.stack([x] * self.output_capsule_num, dim=2).unsqueeze(4)
        W = self.weight.expand(self.batch_size, *self.weight.size())
        u_hat = torch.matmul(W, x).permute(1, 2, 0, 3, 4).squeeze().contiguous()
        # Routing logits b_ij start at zero, one scalar per edge; the edge
        # order matches the (i-major) order used in construct_graph.
        b_ij = torch.zeros(self.input_capsule_num, self.output_capsule_num).to(self.device)
        self.g.set_e_repr({'b_ij': b_ij.view(-1)})
        self.g.set_e_repr({'u_hat': u_hat.view(-1, self.batch_size, self.output_capsule_dim)})
        node_features = torch.zeros(self.input_capsule_num + self.output_capsule_num,
                                    self.batch_size, self.output_capsule_dim).to(self.device)
        self.g.set_n_repr({'h': node_features})
        # Dynamic routing: alternate message passing with agreement updates.
        for i in range(self.num_routing):
            self.g.update_all(self.capsule_msg, self.capsule_reduce, self.capsule_update)
            self.g.update_edge(edge_func=self.update_edge)
        this_layer_nodes_feature = self.g.get_n_repr()['h'][
            self.input_capsule_num:self.input_capsule_num + self.output_capsule_num]
        return this_layer_nodes_feature.transpose(0, 1).unsqueeze(1).unsqueeze(4).squeeze(1)

    def update_edge(self, u, v, edge):
        # Agreement update: b_ij += <v_j, u_hat_ij>, averaged over the batch.
        return {'b_ij': edge['b_ij'] + (v['h'] * edge['u_hat']).mean(dim=1).sum(dim=1)}

    @staticmethod
    def capsule_msg(src, edge):
        return {'b_ij': edge['b_ij'], 'h': src['h'], 'u_hat': edge['u_hat']}

    @staticmethod
    def capsule_reduce(node, msg):
        b_ij_c, u_hat = msg['b_ij'], msg['u_hat']
        # Coupling coefficients c_i are a softmax over the routing logits;
        # s_j is the resulting weighted sum of the prediction vectors.
        c_i = F.softmax(b_ij_c, dim=0)
        s_j = (c_i.unsqueeze(2).unsqueeze(3) * u_hat).sum(dim=1)
        return {'h': s_j}

    @staticmethod
    def capsule_update(msg):
        # Apply the squashing nonlinearity to get the output capsules v_j.
        v_j = squash(msg['h'])
        return {'h': v_j}


def squash(s, dim=2):
    # Squashing nonlinearity: rescales s so its norm lies in [0, 1) while
    # preserving its direction.
    sq = torch.sum(s ** 2, dim=dim, keepdim=True)
    s_std = torch.sqrt(sq)
    s = (sq / (1.0 + sq)) * (s / s_std)
    return s
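As a quick sanity check (hypothetical, not part of this commit), the pieces above can be exercised directly. The sketch assumes the file is saved as `DGLDigitCapsule.py` and a DGL release that still provides the `set_e_repr`/`update_edge` API used here:

```python
# Hypothetical sanity check, assuming DGLDigitCapsule.py as above and a
# DGL version with the set_e_repr/update_edge API.
import torch
from DGLDigitCapsule import DGLDigitCapsuleLayer, squash

s = torch.randn(4, 8, 1152)                # (batch, capsule_dim, capsule_num)
v = squash(s)                              # squash acts along dim=2 by default
assert float(v.norm(dim=2).max()) < 1.0    # norms are pushed into [0, 1)

layer = DGLDigitCapsuleLayer()             # defaults: 1152 x 8-dim -> 10 x 16-dim
out = layer(s)
print(out.shape)                           # expected: torch.Size([4, 10, 16, 1])
```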
DGL implementation of Capsule Network
=====================================
This repo implements Hinton and his team's [Capsule Network](https://arxiv.org/abs/1710.09829).
For simplicity, only the margin loss is implemented; the aim is to keep the DGL usage easy to follow.
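For reference, the margin loss from the paper, with the constants used in `model.py` (m+ = 0.9, m- = 0.1, lambda = 0.5):

```latex
L_c = T_c \max(0, m^+ - \lVert v_c \rVert)^2
    + \lambda (1 - T_c) \max(0, \lVert v_c \rVert - m^-)^2
```

Here `T_c = 1` iff digit `c` is present, and `||v_c||` is the norm of the c-th output capsule.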
## Training & Evaluation
```bash
# Run with default config
python main.py
# Run with train and test batch size 128, and for 50 epochs
python main.py --batch-size 128 --test-batch-size 128 --epochs 50
```
# main.py
import argparse

import torch
import torch.optim as optim
from torchvision import datasets, transforms

from model import Net


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = model.margin_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += model.margin_loss(output, target).item()  # sum up batch loss
            # The predicted digit is the output capsule with the largest norm.
            pred = output.norm(dim=2).squeeze().max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=512, metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--test-batch-size', type=int, default=512, metavar='N',
                        help='input batch size for testing (default: 512)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)
    model = Net(device=device).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)


if __name__ == '__main__':
    main()
# model.py
import torch
from torch import nn

from DGLDigitCapsule import DGLDigitCapsuleLayer, squash


class Net(nn.Module):
    def __init__(self, device='cpu'):
        super(Net, self).__init__()
        self.device = device
        # Conv1: a plain convolution + ReLU producing 256 feature maps.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1),
            nn.ReLU(inplace=True))
        self.primary = PrimaryCapsuleLayer(device=device)
        self.digits = DGLDigitCapsuleLayer(device=device)

    def forward(self, x):
        out_conv1 = self.conv1(x)
        out_primary_caps = self.primary(out_conv1)
        out_digit_caps = self.digits(out_primary_caps)
        return out_digit_caps

    def margin_loss(self, input, target):
        batch_size = target.size(0)
        # One-hot encode the labels.
        one_hot_vec = torch.zeros(batch_size, 10).to(self.device)
        for i in range(batch_size):
            one_hot_vec[i, target[i]] = 1.0
        # The capsule norm ||v_c|| acts as the class probability.
        v_c = torch.sqrt((input ** 2).sum(dim=2, keepdim=True))
        zero = torch.zeros(1).to(self.device)
        m_plus = 0.9
        m_minus = 0.1
        loss_lambda = 0.5
        max_left = torch.max(m_plus - v_c, zero).view(batch_size, -1) ** 2
        max_right = torch.max(v_c - m_minus, zero).view(batch_size, -1) ** 2
        t_c = one_hot_vec
        l_c = t_c * max_left + loss_lambda * (1.0 - t_c) * max_right
        l_c = l_c.sum(dim=1)
        return l_c.mean()


class PrimaryCapsuleLayer(nn.Module):
    def __init__(self, in_channel=256, num_unit=8, device='cpu'):
        super(PrimaryCapsuleLayer, self).__init__()
        self.in_channel = in_channel
        self.num_unit = num_unit
        self.device = device
        # Eight parallel conv units; each contributes one dimension group
        # of the primary capsules.
        self.conv_units = nn.ModuleList([
            nn.Conv2d(self.in_channel, 32, 9, 2) for _ in range(self.num_unit)
        ])

    def forward(self, x):
        unit = [conv(x) for conv in self.conv_units]
        unit = torch.stack(unit, dim=1)
        batch_size = x.size(0)
        unit = unit.view(batch_size, self.num_unit, -1)
        return squash(unit, dim=2)
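A minimal end-to-end smoke test (hypothetical, not part of the commit) that pushes a fake MNIST batch through the network and evaluates the margin loss; it assumes the files are saved as `model.py` and `DGLDigitCapsule.py` as shown above:

```python
# Hypothetical smoke test on random data.
import torch
from model import Net

model = Net(device='cpu')
x = torch.randn(4, 1, 28, 28)              # a fake MNIST batch
out = model(x)                             # (4, 10, 16, 1): one 16-dim capsule per digit
target = torch.tensor([3, 1, 4, 1])
loss = model.margin_loss(out, target)      # scalar margin loss
print(out.shape, loss.item())
```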