"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "494de206ce029cf7d03a12eeb7d72368d04d7458"
Unverified commit afc83aa2, authored by Chen Sirui, committed by GitHub

Graphsim (#2794)



* Add hgat example

* Add experiment

* Clean code

* clear the code

* Add index in README

* Add index in README

* Add index in README

* Add index in README

* Add index in README

* Add index in README

* Change the code title and folder name

* Ready to merge

* Prepare for rebase and change message passing function

* use git ignore to handle empty file

* change file permission to resolve empty file

* Change permission

* change file mode

* Finish Coding

* working code cpu

* pyg compare

* Accelerate with batching

* FastMode Enabled

* update readme

* Update README.md

* refactor code

* add graphsim code

* modified code

* few fix

* Modified graphsim

* Simple Model Added

* Clean up code

* Refactor the code for merge

* Bugfix enable gradient when train

* update readme and format
Co-authored-by: Chen <chesirui@3c22fbe5458c.ant.amazon.com>
Co-authored-by: Tianjun Xiao <xiaotj1990327@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-4-63.ap-northeast-1.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-45-47.ap-northeast-1.compute.internal>
parent 4c7476c8
@@ -88,6 +88,7 @@ The folder contains example implementations of selected research papers
| [Variational Graph Auto-Encoders](#vgae) | | :heavy_check_mark: | | | |
| [Composition-based Multi-Relational Graph Convolutional Networks](#compgcn) | | :heavy_check_mark: | | | |
| [GNNExplainer: Generating Explanations for Graph Neural Networks](#gnnexplainer) | :heavy_check_mark: | | | | |
| [Interaction Networks for Learning about Objects, Relations and Physics](#graphsim) | | | :heavy_check_mark: | | |
| [Representation Learning on Graphs with Jumping Knowledge Networks](#jknet) | :heavy_check_mark: | | | | |

## 2021
@@ -133,6 +134,10 @@ The folder contains example implementations of selected research papers
- <a name="dagnn"></a> Rossi et al. Temporal Graph Networks For Deep Learning on Dynamic Graphs. [Paper link](https://arxiv.org/abs/2006.10637).
  - Example code: [Pytorch](../examples/pytorch/tgn)
  - Tags: over-smoothing, node classification
- <a name="compgcn"></a> Vashishth, Shikhar, et al. Composition-based Multi-Relational Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1911.03082).
  - Example code: [Pytorch](../examples/pytorch/compGCN)
  - Tags: multi-relational graphs, graph neural network
# GraphParticleSim

## DGL Implementation of the Interaction Network paper

This DGL example implements the GNN model proposed in the paper [Interaction Network](https://arxiv.org/abs/1612.00222).

GraphParticleSim implementor
----------------------------
This example was implemented by [Ericcsr](https://github.com/Ericcsr) during his internship at the AWS Shanghai AI Lab.
The graph dataset used in this example
---------------------------------------
This example uses datasets generated by a physics N-body simulator adapted from [this repo](https://github.com/jsikyoon/Interaction-networks_tensorflow).

n_body:
- n particles/nodes
- complete bidirectional graph
- 10 trajectories generated by default
- 1000 simulation steps per trajectory
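The layout above maps to arrays of shape `(steps, n_body, 5)` with per-body state `[mass, x, y, x_vel, y_vel]`. A minimal sketch (toy sizes and a constant-velocity motion, not the real simulator) of how current-state/next-velocity training pairs are laid out:

```python
import numpy as np

# Toy sizes for illustration; the example's defaults are 10 trajectories
# of 1000 steps with 6 bodies and 5 features [mass, x, y, x_vel, y_vel].
steps, n_body, fea_num = 4, 3, 5

traj = np.zeros((steps, n_body, fea_num))
traj[:, :, 0] = 1.0      # constant mass per body
traj[:, :, 3:5] = 0.5    # constant velocity for this toy trajectory
for t in range(1, steps):
    # Advance positions by velocity * dt (dt = 0.001 as in the simulator)
    traj[t, :, 1:3] = traj[t - 1, :, 1:3] + traj[t - 1, :, 3:5] * 0.001

# Training pairs: current full state -> next-step velocity only.
data = traj[:-1]            # shape (steps-1, n_body, 5)
label = traj[1:, :, 3:5]    # shape (steps-1, n_body, 2)
print(data.shape, label.shape)
```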
Dependencies
--------------------------------
- ffmpeg 4.3.8
- opencv-python 4.2.0
How to run example files
--------------------------------
All commands below are run in the graphsim folder.

**First run `n_body_sim.py` to generate data.** The ground-truth velocities are taken directly from the simulator.

```bash
python n_body_sim.py
```

To generate longer trajectories, or more of them:

```bash
python n_body_sim.py --num_traj <num_traj> --steps <num_steps>
```

**Then train with `train.py`:**

```bash
python train.py --num_workers 15
```

Training with GPU:

```bash
python train.py --gpu 0 --num_workers 15
```

Training with visualization (a valid visualization may require the full 40000 training epochs):

```bash
python train.py --gpu 0 --num_workers 15 --visualize
```
One-step loss on test data after 40000 training epochs
------------------------------------------------------

| Models/Dataset | 6 Body |
| :-------------- | -----: |
| Interaction Network in DGL | 80(10) |
| Interaction Network in Tensorflow | 60 |
Note that the datasets are generated directly from the simulator to avoid using Tensorflow to process the original dataset. Training is very unstable: even when the minimum loss is reached from time to time, the loss may later increase suddenly, in both the original authors' model and ours. Since the original model has not been released, this implementation follows the Tensorflow version at https://github.com/jsikyoon/Interaction-networks_tensorflow, whose author consulted the paper's first author on some implementation details.
# dataloader.py
import copy

import numpy as np
import torch
import networkx as nx
from torch.utils.data import Dataset

import dgl
def build_dense_graph(n_particles):
g = nx.complete_graph(n_particles)
return dgl.from_networkx(g)
class MultiBodyDataset(Dataset):
def __init__(self, path):
self.path = path
self.zipfile = np.load(self.path)
self.node_state = self.zipfile['data']
self.node_label = self.zipfile['label']
self.n_particles = self.zipfile['n_particles']
def __len__(self):
return self.node_state.shape[0]
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
node_state = self.node_state[idx, :, :]
node_label = self.node_label[idx, :, :]
return (node_state, node_label)
class MultiBodyTrainDataset(MultiBodyDataset):
def __init__(self, data_path='./data/'):
super(MultiBodyTrainDataset, self).__init__(
data_path+'n_body_train.npz')
self.stat_median = self.zipfile['median']
self.stat_max = self.zipfile['max']
self.stat_min = self.zipfile['min']
class MultiBodyValidDataset(MultiBodyDataset):
def __init__(self, data_path='./data/'):
super(MultiBodyValidDataset, self).__init__(
data_path+'n_body_valid.npz')
class MultiBodyTestDataset(MultiBodyDataset):
def __init__(self, data_path='./data/'):
super(MultiBodyTestDataset, self).__init__(data_path+'n_body_test.npz')
self.test_traj = self.zipfile['test_traj']
self.first_frame = torch.from_numpy(self.zipfile['first_frame'])
# Construct fully connected graph
class MultiBodyGraphCollator:
def __init__(self, n_particles):
self.n_particles = n_particles
self.graph = dgl.from_networkx(nx.complete_graph(self.n_particles))
def __call__(self, batch):
graph_list = []
data_list = []
label_list = []
for frame in batch:
graph_list.append(copy.deepcopy(self.graph))
data_list.append(torch.from_numpy(frame[0]))
label_list.append(torch.from_numpy(frame[1]))
graph_batch = dgl.batch(graph_list)
data_batch = torch.vstack(data_list)
label_batch = torch.vstack(label_list)
return graph_batch, data_batch, label_batch
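MultiBodyGraphCollator builds one complete graph per frame and merges them with dgl.batch. A dgl-free sketch of the underlying idea, a disjoint union via node-id offsets (`complete_graph_edges` and `batch_graphs` are illustrative helpers, not part of the example):

```python
def complete_graph_edges(n):
    # Directed edges of a complete graph on n nodes (no self loops),
    # which is the same as a bidirectional complete graph.
    return [(u, v) for u in range(n) for v in range(n) if u != v]

def batch_graphs(sizes):
    # Disjoint union: shift each graph's node ids by a running offset.
    # This is essentially what dgl.batch does under the hood.
    edges, offset = [], 0
    for n in sizes:
        edges.extend((u + offset, v + offset) for u, v in complete_graph_edges(n))
        offset += n
    return edges, offset  # batched edge list, total node count

edges, num_nodes = batch_graphs([3, 3])
print(num_nodes, len(edges))  # 6 nodes, 2 * (3*2) = 12 edges
```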
# models.py
from functools import partial

import torch
import torch.nn as nn
from torch.nn import functional as F

import dgl.function as fn
class MLP(nn.Module):
    def __init__(self, in_feats, out_feats, num_layers=2, hidden=128):
        super(MLP, self).__init__()
        self.layers = nn.ModuleList()
        # Input layer
        self.layers.append(nn.Linear(in_feats, hidden))
        # Hidden layers (only present when num_layers > 2)
        for i in range(1, num_layers-1):
            layer = nn.Linear(hidden, hidden)
            nn.init.normal_(layer.weight, std=0.1)
            nn.init.zeros_(layer.bias)
            self.layers.append(layer)
        # Output layer
        layer = nn.Linear(hidden, out_feats)
        nn.init.normal_(layer.weight, std=0.1)
        nn.init.zeros_(layer.bias)
        self.layers.append(layer)
def forward(self, x):
for l in range(len(self.layers)-1):
x = self.layers[l](x)
x = F.relu(x)
x = self.layers[-1](x)
return x
class PrepareLayer(nn.Module):
    '''
    Generate edge features for the model input and normalize the node
    features.

    Parameters
    ==========
    node_feats : int
        Number of node features
    stat : dict
        Statistics needed for normalization
    '''
def __init__(self, node_feats, stat):
super(PrepareLayer, self).__init__()
self.node_feats = node_feats
# stat {'median':median,'max':max,'min':min}
self.stat = stat
def normalize_input(self, node_feature):
return (node_feature-self.stat['median'])*(2/(self.stat['max']-self.stat['min']))
def forward(self, g, node_feature):
with g.local_scope():
node_feature = self.normalize_input(node_feature)
g.ndata['feat'] = node_feature # Only dynamic feature
g.apply_edges(fn.u_sub_v('feat', 'feat', 'e'))
edge_feature = g.edata['e']
return node_feature, edge_feature
class InteractionNet(nn.Module):
    '''
    Simple Interaction Network

    A one-layer interaction network for stellar multi-body simulation;
    it can simulate the motion of up to 12 bodies.

    Parameters
    ==========
    node_feats : int
        Number of node features
    stat : dict
        Statistics for denormalization
    '''
def __init__(self, node_feats, stat):
super(InteractionNet, self).__init__()
self.node_feats = node_feats
self.stat = stat
edge_fn = partial(MLP, num_layers=5, hidden=150)
node_fn = partial(MLP, num_layers=2, hidden=100)
self.in_layer = InteractionLayer(node_feats-3, # Use velocity only
node_feats,
out_node_feats=2,
out_edge_feats=50,
edge_fn=edge_fn,
node_fn=node_fn,
mode='n_n')
# Denormalize Velocity only
def denormalize_output(self, out):
return out*(self.stat['max'][3:5]-self.stat['min'][3:5])/2+self.stat['median'][3:5]
def forward(self, g, n_feat, e_feat, global_feats, relation_feats):
with g.local_scope():
out_n, out_e = self.in_layer(
g, n_feat, e_feat, global_feats, relation_feats)
out_n = self.denormalize_output(out_n)
return out_n, out_e
class InteractionLayer(nn.Module):
    '''
    A single layer of the interaction network.

    Parameters
    ==========
    in_node_feats : int
        Number of input node features
    in_edge_feats : int
        Number of input edge features
    out_node_feats : int
        Number of node features after one interaction
    out_edge_feats : int
        Number of edge features after one interaction
    global_feats : int
        Number of global features used as input
    relate_feats : int
        Number of features describing the relation between the objects
    edge_fn : torch.nn.Module
        Function that updates edge features during message generation
    node_fn : torch.nn.Module
        Function that updates node features during message aggregation
    mode : str
        Type of message the edges carry:
        'nne' : [src_feat, dst_feat, edge_feat], node features concatenated
                with the edge feature.
        'n_n' : [src_feat - dst_feat], node features subtracted from each
                other.
    '''
def __init__(self, in_node_feats,
in_edge_feats,
out_node_feats,
out_edge_feats,
global_feats=1,
relate_feats=1,
edge_fn=nn.Linear,
node_fn=nn.Linear,
mode='nne'): # 'n_n'
super(InteractionLayer, self).__init__()
self.in_node_feats = in_node_feats
self.in_edge_feats = in_edge_feats
self.out_edge_feats = out_edge_feats
self.out_node_feats = out_node_feats
self.mode = mode
# MLP for message passing
input_shape = 2*self.in_node_feats + \
self.in_edge_feats if mode == 'nne' else self.in_edge_feats+relate_feats
self.edge_fn = edge_fn(input_shape,
self.out_edge_feats) # 50 in IN paper
self.node_fn = node_fn(self.in_node_feats+self.out_edge_feats+global_feats,
self.out_node_feats)
# Should be done by apply edge
def update_edge_fn(self, edges):
x = torch.cat([edges.src['feat'], edges.dst['feat'],
edges.data['feat']], dim=1)
ret = F.relu(self.edge_fn(
x)) if self.mode == 'nne' else self.edge_fn(x)
return {'e': ret}
    # Assume 'agg' comes from the built-in reduce function
def update_node_fn(self, nodes):
x = torch.cat([nodes.data['feat'], nodes.data['agg']], dim=1)
ret = F.relu(self.node_fn(
x)) if self.mode == 'nne' else self.node_fn(x)
return {'n': ret}
def forward(self, g, node_feats, edge_feats, global_feats, relation_feats):
# print(node_feats.shape,global_feats.shape)
g.ndata['feat'] = torch.cat([node_feats, global_feats], dim=1)
g.edata['feat'] = torch.cat([edge_feats, relation_feats], dim=1)
if self.mode == 'nne':
g.apply_edges(self.update_edge_fn)
else:
g.edata['e'] = self.edge_fn(g.edata['feat'])
g.update_all(fn.copy_e('e', 'msg'),
fn.sum('msg', 'agg'),
self.update_node_fn)
return g.ndata['n'], g.edata['e']
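One interaction step can be sketched without DGL: in the 'n_n' mode above, each edge carries a function of the node-feature difference, and each node combines its own features with the sum of its incoming messages. A minimal numpy version with linear maps standing in for the edge and node MLPs (`interaction_step` is an illustrative helper, not the example's API):

```python
import numpy as np

def interaction_step(x, W_e, W_n):
    # x: (n, d) node features. The message on edge (u, v) is
    # W_e @ (x[u] - x[v]); each node v then maps [x[v], sum of messages]
    # through W_n, mirroring one 'n_n'-mode interaction layer.
    n, d = x.shape
    agg = np.zeros((n, W_e.shape[0]))
    for u in range(n):
        for v in range(n):
            if u != v:
                agg[v] += W_e @ (x[u] - x[v])   # message sent u -> v
    return np.concatenate([x, agg], axis=1) @ W_n.T

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2))
W_e = rng.normal(size=(3, 2))       # edge map: 2 -> 3 features
W_n = rng.normal(size=(2, 2 + 3))   # node map: [own, aggregated] -> 2
out = interaction_step(x, W_e, W_n)
print(out.shape)  # (4, 2)
```

With identical node features all pairwise differences vanish, so every node receives zero aggregate message, a quick sanity check on the 'n_n' formulation.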
# n_body_sim.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
from math import sin, cos, radians, pi

import numpy as np

'''
Adapted from https://github.com/jsikyoon/Interaction-networks_tensorflow,
which generates multi-body dynamics simulation data for the Interaction
Network.
'''
# 5 features on the state [mass,x,y,x_vel,y_vel]
fea_num = 5
# G stands for the gravitational constant; 10**5 helps numerical stability
G = 10**5
# time step
diff_t = 0.001
def init(total_state, n_body, fea_num, orbit):
data = np.zeros((total_state, n_body, fea_num), dtype=float)
if(orbit):
data[0][0][0] = 100
data[0][0][1:5] = 0.0
# The position are initialized randomly.
for i in range(1, n_body):
data[0][i][0] = np.random.rand()*8.98+0.02
distance = np.random.rand()*90.0+10.0
theta = np.random.rand()*360
theta_rad = pi/2 - radians(theta)
data[0][i][1] = distance*cos(theta_rad)
data[0][i][2] = distance*sin(theta_rad)
data[0][i][3] = -1*data[0][i][2]/norm(data[0][i][1:3])*(
G*data[0][0][0]/norm(data[0][i][1:3])**2)*distance/1000
data[0][i][4] = data[0][i][1]/norm(data[0][i][1:3])*(
G*data[0][0][0]/norm(data[0][i][1:3])**2)*distance/1000
else:
for i in range(n_body):
data[0][i][0] = np.random.rand()*8.98+0.02
distance = np.random.rand()*90.0+10.0
theta = np.random.rand()*360
theta_rad = pi/2 - radians(theta)
data[0][i][1] = distance*cos(theta_rad)
data[0][i][2] = distance*sin(theta_rad)
data[0][i][3] = np.random.rand()*6.0-3.0
data[0][i][4] = np.random.rand()*6.0-3.0
return data
def norm(x):
return np.sqrt(np.sum(x**2))
def get_f(receiver, sender):
    # Gravitational force on `receiver` from `sender`; the distance is
    # clamped below by 1 to avoid numerical blow-up at close range.
    diff = sender[1:3]-receiver[1:3]
    distance = norm(diff)
    if distance < 1:
        distance = 1
    return G*receiver[0]*sender[0]/(distance**3)*diff
# Compute stat according to the paper for normalization
def compute_stats(train_curr):
data = np.vstack(train_curr).reshape(-1, fea_num)
stat_median = np.median(data, axis=0)
stat_max = np.quantile(data, 0.95, axis=0)
stat_min = np.quantile(data, 0.05, axis=0)
return stat_median, stat_max, stat_min
def calc(cur_state, n_body):
next_state = np.zeros((n_body, fea_num), dtype=float)
f_mat = np.zeros((n_body, n_body, 2), dtype=float)
f_sum = np.zeros((n_body, 2), dtype=float)
acc = np.zeros((n_body, 2), dtype=float)
    for i in range(n_body):
        for j in range(i+1, n_body):
            # Newton's third law: equal and opposite pairwise forces
            f = get_f(cur_state[i][:3], cur_state[j][:3])
            f_mat[i, j] += f
            f_mat[j, i] -= f
        f_sum[i] = np.sum(f_mat[i], axis=0)
        acc[i] = f_sum[i]/cur_state[i][0]
        next_state[i][0] = cur_state[i][0]
        next_state[i][3:5] = cur_state[i][3:5]+acc[i]*diff_t
        next_state[i][1:3] = cur_state[i][1:3]+next_state[i][3:5]*diff_t
return next_state
# The state is [mass,pos_x,pos_y,vel_x,vel_y]* n_body
def gen(n_body, num_steps, orbit):
# initialization on just first state
data = init(num_steps, n_body, fea_num, orbit)
for i in range(1, num_steps):
data[i] = calc(data[i-1], n_body)
return data
if __name__ == '__main__':
argparser = argparse.ArgumentParser()
argparser.add_argument('--num_bodies', type=int, default=6)
argparser.add_argument('--num_traj', type=int, default=10)
argparser.add_argument('--steps', type=int, default=1000)
argparser.add_argument('--data_path', type=str, default='data')
args = argparser.parse_args()
if not os.path.exists(args.data_path):
os.mkdir(args.data_path)
# Generate data
data_curr = []
data_next = []
for i in range(args.num_traj):
raw_traj = gen(args.num_bodies, args.steps, True)
data_curr.append(raw_traj[:-1])
data_next.append(raw_traj[1:])
print("Train Traj: ", i)
# Compute normalization statistic from data
stat_median, stat_max, stat_min = compute_stats(data_curr)
data = np.vstack(data_curr)
label = np.vstack(data_next)[:, :, 3:5]
shuffle_idx = np.arange(data.shape[0])
np.random.shuffle(shuffle_idx)
train_split = int(0.9*data.shape[0])
valid_split = train_split+300
data = data[shuffle_idx]
label = label[shuffle_idx]
train_data = data[:train_split]
train_label = label[:train_split]
valid_data = data[train_split:valid_split]
valid_label = label[train_split:valid_split]
test_data = data[valid_split:]
test_label = label[valid_split:]
np.savez(args.data_path+'/n_body_train.npz',
data=train_data,
label=train_label,
n_particles=args.num_bodies,
median=stat_median,
max=stat_max,
min=stat_min)
np.savez(args.data_path+'/n_body_valid.npz',
data=valid_data,
label=valid_label,
n_particles=args.num_bodies)
test_traj = gen(args.num_bodies, args.steps, True)
np.savez(args.data_path+'/n_body_test.npz',
data=test_data,
label=test_label,
n_particles=args.num_bodies,
first_frame=test_traj[0],
test_traj=test_traj)
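calc() advances the state with semi-implicit Euler: velocity is updated from the acceleration first, then position is updated with the new velocity. A standalone sketch of that update order (`euler_step` is an illustrative helper), checked against the closed form v = a*t under constant acceleration:

```python
import numpy as np

def euler_step(pos, vel, acc, dt=0.001):
    # Same update order as calc(): velocity first, then position
    # using the *new* velocity (semi-implicit Euler).
    vel = vel + acc * dt
    pos = pos + vel * dt
    return pos, vel

pos, vel = np.zeros(2), np.zeros(2)
acc = np.array([1.0, 0.0])
for _ in range(1000):
    pos, vel = euler_step(pos, vel, acc)
# After 1000 steps of dt=0.001 under unit acceleration,
# the velocity approaches a * t = 1.0.
print(vel[0], pos[0])
```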
# train.py
import time
import argparse
import traceback

import numpy as np
import torch
from torch.utils.data import DataLoader
import networkx as nx

import dgl

from models import InteractionNet, PrepareLayer
from dataloader import MultiBodyGraphCollator, MultiBodyTrainDataset,\
    MultiBodyValidDataset, MultiBodyTestDataset
from utils import make_video
def train(optimizer, loss_fn, reg_fn, model, prep, dataloader, lambda_reg, device):
total_loss = 0
model.train()
for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader):
graph_batch = graph_batch.to(device)
data_batch = data_batch.to(device)
label_batch = label_batch.to(device)
optimizer.zero_grad()
node_feat, edge_feat = prep(graph_batch, data_batch)
dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device)
dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device)
v_pred, out_e = model(graph_batch, node_feat[:, 3:5].float(
), edge_feat.float(), dummy_global, dummy_relation)
loss = loss_fn(v_pred, label_batch)
total_loss += float(loss)
zero_target = torch.zeros_like(out_e)
loss = loss + lambda_reg*reg_fn(out_e, zero_target)
reg_loss = 0
for param in model.parameters():
reg_loss = reg_loss + lambda_reg * \
reg_fn(param, torch.zeros_like(
param).float().to(device))
loss = loss + reg_loss
loss.backward()
optimizer.step()
return total_loss/(i+1)
# One step evaluation
def eval(loss_fn, model, prep, dataloader, device):
total_loss = 0
model.eval()
for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader):
graph_batch = graph_batch.to(device)
data_batch = data_batch.to(device)
label_batch = label_batch.to(device)
node_feat, edge_feat = prep(graph_batch, data_batch)
dummy_relation = torch.zeros(
edge_feat.shape[0], 1).float().to(device)
dummy_global = torch.zeros(
node_feat.shape[0], 1).float().to(device)
v_pred, _ = model(graph_batch, node_feat[:, 3:5].float(
), edge_feat.float(), dummy_global, dummy_relation)
loss = loss_fn(v_pred, label_batch)
total_loss += float(loss)
return total_loss/(i+1)
# Rollout evaluation starting from an initial state; the predicted
# velocities are integrated to obtain positions.
def eval_rollout(model, prep, initial_frame, n_object, device):
current_frame = initial_frame.to(device)
base_graph = nx.complete_graph(n_object)
graph = dgl.from_networkx(base_graph).to(device)
pos_buffer = []
model.eval()
for step in range(100):
node_feats, edge_feats = prep(graph, current_frame)
dummy_relation = torch.zeros(
edge_feats.shape[0], 1).float().to(device)
dummy_global = torch.zeros(
node_feats.shape[0], 1).float().to(device)
v_pred, _ = model(graph, node_feats[:, 3:5].float(
), edge_feats.float(), dummy_global, dummy_relation)
current_frame[:, [1, 2]] += v_pred*0.001
current_frame[:, 3:5] = v_pred
pos_buffer.append(current_frame[:, [1, 2]].cpu().numpy())
pos_buffer = np.vstack(pos_buffer).reshape(100, n_object, -1)
make_video(pos_buffer, 'video_model.mp4')
if __name__ == '__main__':
argparser = argparse.ArgumentParser()
argparser.add_argument('--lr', type=float, default=0.001,
help='learning rate')
argparser.add_argument('--epochs', type=int, default=40000,
help='Number of epochs in training')
argparser.add_argument('--lambda_reg', type=float, default=0.001,
help='regularization weight')
argparser.add_argument('--gpu', type=int, default=-1,
help='gpu device code, -1 means cpu')
argparser.add_argument('--batch_size', type=int, default=100,
help='size of each mini batch')
argparser.add_argument('--num_workers', type=int, default=0,
help='number of workers for dataloading')
argparser.add_argument('--visualize', action='store_true', default=False,
help='Whether enable trajectory rollout mode for visualization')
args = argparser.parse_args()
# Select Device to be CPU or GPU
if args.gpu != -1:
device = torch.device('cuda:{}'.format(args.gpu))
else:
device = torch.device('cpu')
train_data = MultiBodyTrainDataset()
valid_data = MultiBodyValidDataset()
test_data = MultiBodyTestDataset()
collator = MultiBodyGraphCollator(train_data.n_particles)
train_dataloader = DataLoader(
train_data, args.batch_size, True, collate_fn=collator, num_workers=args.num_workers)
valid_dataloader = DataLoader(
valid_data, args.batch_size, True, collate_fn=collator, num_workers=args.num_workers)
test_full_dataloader = DataLoader(
test_data, args.batch_size, True, collate_fn=collator, num_workers=args.num_workers)
node_feats = 5
stat = {'median': torch.from_numpy(train_data.stat_median).to(device),
'max': torch.from_numpy(train_data.stat_max).to(device),
'min': torch.from_numpy(train_data.stat_min).to(device)}
print("Weight: ", train_data.stat_median[0],
train_data.stat_max[0], train_data.stat_min[0])
print("Position: ", train_data.stat_median[[
1, 2]], train_data.stat_max[[1, 2]], train_data.stat_min[[1, 2]])
print("Velocity: ", train_data.stat_median[[
3, 4]], train_data.stat_max[[3, 4]], train_data.stat_min[[3, 4]])
prepare_layer = PrepareLayer(node_feats, stat).to(device)
interaction_net = InteractionNet(node_feats, stat).to(device)
print(interaction_net)
optimizer = torch.optim.Adam(interaction_net.parameters(), lr=args.lr)
state_dict = interaction_net.state_dict()
loss_fn = torch.nn.MSELoss()
reg_fn = torch.nn.MSELoss(reduction='sum')
try:
for e in range(args.epochs):
last_t = time.time()
            loss = train(optimizer, loss_fn, reg_fn, interaction_net,
                         prepare_layer, train_dataloader, args.lambda_reg, device)
print("Epoch time: ", time.time()-last_t)
if e % 1 == 0:
valid_loss = eval(loss_fn, interaction_net,
prepare_layer, valid_dataloader, device)
test_full_loss = eval(
loss_fn, interaction_net, prepare_layer, test_full_dataloader, device)
print("Epoch: {}.Loss: Valid: {} Full: {}".format(
e, valid_loss, test_full_loss))
    except Exception:
traceback.print_exc()
finally:
if args.visualize:
eval_rollout(interaction_net, prepare_layer,
test_data.first_frame, test_data.n_particles, device)
make_video(test_data.test_traj[:100, :, [1, 2]], 'video_truth.mp4')
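The training loss above is MSE on predicted velocities plus an L2 penalty on the edge outputs and on every parameter tensor. A numpy sketch of the parameter-penalty part (`l2_regularized_loss` is an illustrative helper; `lambda_reg` matches the script's default of 0.001):

```python
import numpy as np

def l2_regularized_loss(pred, target, params, lambda_reg=0.001):
    # MSE data term plus lambda * sum of squared entries over all
    # parameter tensors, mirroring the reg_fn(param, zeros) term
    # with reduction='sum' in train().
    mse = np.mean((pred - target) ** 2)
    reg = sum(np.sum(p ** 2) for p in params)
    return mse + lambda_reg * reg

pred = np.array([1.0, 2.0])
target = np.array([1.0, 0.0])
params = [np.array([[1.0, -1.0]]), np.array([2.0])]
print(l2_regularized_loss(pred, target, params))  # 2.0 + 0.001 * 6 = 2.006
```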
# utils.py
import os

import matplotlib
matplotlib.use('agg')  # select a non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
import matplotlib.animation as manimation
# make_video renders a particle trajectory to an mp4 file and can be
# used to visualize both model rollouts and ground-truth test data.
def make_video(xy, filename):
    os.system("rm -rf pics/*")
    FFMpegWriter = manimation.writers['ffmpeg']
    metadata = dict(title='Movie Test', artist='Matplotlib',
                    comment='Movie support!')
    writer = FFMpegWriter(fps=15, metadata=metadata)
    fig = plt.figure()
    plt.xlim(-200, 200)
    plt.ylim(-200, 200)
    color = ['ro', 'bo', 'go', 'ko', 'yo', 'mo', 'co']
with writer.saving(fig, filename, len(xy)):
for i in range(len(xy)):
for j in range(len(xy[0])):
plt.plot(xy[i, j, 1], xy[i, j, 0], color[j % len(color)])
writer.grab_frame()