import argparse
import time
import traceback

import dgl
import networkx as nx
import numpy as np
import torch
from torch.utils.data import DataLoader

from dataloader import (
    MultiBodyGraphCollator,
    MultiBodyTestDataset,
    MultiBodyTrainDataset,
    MultiBodyValidDataset,
)
from models import InteractionNet, MLP, PrepareLayer
from utils import make_video


def train(
    optimizer, loss_fn, reg_fn, model, prep, dataloader, lambda_reg, device
):
    """Run one training epoch and return the mean prediction loss.

    Args:
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(v_pred, label_batch)`` giving the
            prediction loss.
        reg_fn: Callable ``reg_fn(tensor, zeros)`` used as a penalty on the
            model's second output and on every model parameter.
        model: Module called as ``model(graph, node_in, edge_in, dummy_global,
            dummy_relation)`` and returning ``(v_pred, out_e)``.
        prep: Callable ``prep(graph_batch, data_batch)`` returning
            ``(node_feat, edge_feat)``.
        dataloader: Iterable of ``(graph_batch, data_batch, label_batch)``.
        lambda_reg: Weight applied to both regularization terms.
        device: Device every batch is moved to.

    Returns:
        The average of ``loss_fn`` over the batches, as a Python float. The
        regularization terms are optimized but are not part of this value.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    total_loss = 0.0
    num_batches = 0
    model.train()
    for graph_batch, data_batch, label_batch in dataloader:
        graph_batch = graph_batch.to(device)
        data_batch = data_batch.to(device)
        label_batch = label_batch.to(device)
        optimizer.zero_grad()

        node_feat, edge_feat = prep(graph_batch, data_batch)
        # All-zero placeholders for the model's last two inputs, allocated
        # directly on the target device.
        dummy_relation = torch.zeros(
            edge_feat.shape[0], 1, dtype=torch.float32, device=device
        )
        dummy_global = torch.zeros(
            node_feat.shape[0], 1, dtype=torch.float32, device=device
        )
        # Only columns 3:5 of the prepared node features are fed to the model.
        v_pred, out_e = model(
            graph_batch,
            node_feat[:, 3:5].float(),
            edge_feat.float(),
            dummy_global,
            dummy_relation,
        )

        pred_loss = loss_fn(v_pred, label_batch)
        # Track the prediction loss alone so the value is comparable with the
        # one reported by evaluation.
        total_loss += float(pred_loss)

        # Penalize the model's second output and all parameters toward zero.
        loss = pred_loss + lambda_reg * reg_fn(out_e, torch.zeros_like(out_e))
        reg_loss = 0
        for param in model.parameters():
            # zeros_like already matches the parameter's dtype and device.
            reg_loss = reg_loss + lambda_reg * reg_fn(
                param, torch.zeros_like(param)
            )
        loss = loss + reg_loss

        loss.backward()
        optimizer.step()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return total_loss / num_batches

# One step evaluation


def eval(loss_fn, model, prep, dataloader, device):
    """Compute the mean one-step prediction loss over ``dataloader``.

    The model is switched to evaluation mode and the whole pass runs under
    ``torch.no_grad()``, so no autograd graph is built and no gradients are
    produced.

    NOTE: the name shadows the builtin ``eval``; it is kept because callers
    use it.

    Args:
        loss_fn: Callable ``loss_fn(v_pred, label_batch)``.
        model: Module called as ``model(graph, node_in, edge_in, dummy_global,
            dummy_relation)``; only its first output is used.
        prep: Callable ``prep(graph_batch, data_batch)`` returning
            ``(node_feat, edge_feat)``.
        dataloader: Iterable of ``(graph_batch, data_batch, label_batch)``.
        device: Device every batch is moved to.

    Returns:
        The average of ``loss_fn`` over the batches, as a Python float.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    total_loss = 0.0
    num_batches = 0
    model.eval()
    with torch.no_grad():
        for graph_batch, data_batch, label_batch in dataloader:
            graph_batch = graph_batch.to(device)
            data_batch = data_batch.to(device)
            label_batch = label_batch.to(device)

            node_feat, edge_feat = prep(graph_batch, data_batch)
            # All-zero placeholders for the model's last two inputs.
            dummy_relation = torch.zeros(
                edge_feat.shape[0], 1, dtype=torch.float32, device=device
            )
            dummy_global = torch.zeros(
                node_feat.shape[0], 1, dtype=torch.float32, device=device
            )
            # Only columns 3:5 of the prepared node features are fed in.
            v_pred, _ = model(
                graph_batch,
                node_feat[:, 3:5].float(),
                edge_feat.float(),
                dummy_global,
                dummy_relation,
            )
            total_loss += float(loss_fn(v_pred, label_batch))
            num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return total_loss / num_batches

# Rollout evaluation based on the initial state
# Need to integrate


def eval_rollout(
    model, prep, initial_frame, n_object, device, n_steps=100, dt=0.001
):
    """Roll the model forward from ``initial_frame`` and render the result.

    At every step the model's first output is used to advance columns 1-2 of
    the frame by ``v_pred * dt`` and is then written into columns 3-4. The
    values of columns 1-2 after each step are collected and handed to
    ``make_video`` under the file name ``"video_model.mp4"``.

    Args:
        model: Module called as ``model(graph, node_in, edge_in, dummy_global,
            dummy_relation)``; only its first output is used.
        prep: Callable ``prep(graph, frame)`` returning
            ``(node_feats, edge_feats)``.
        initial_frame: 2-D tensor with one row per object; it is copied, so
            the caller's tensor is left untouched.
        n_object: Number of objects; a complete graph over this many nodes is
            built for the rollout.
        device: Device the rollout runs on.
        n_steps: Number of rollout steps (default 100, the previously
            hard-coded value).
        dt: Factor applied to the predicted output when advancing columns 1-2
            (default 0.001, the previously hard-coded value).
    """
    # ``.to(device)`` may return the very same tensor (e.g. CPU -> CPU), so
    # clone before the in-place updates below.
    current_frame = initial_frame.to(device).clone()
    base_graph = nx.complete_graph(n_object)
    graph = dgl.from_networkx(base_graph).to(device)
    pos_buffer = []
    model.eval()
    # No gradients are needed here; disabling autograd avoids chaining history
    # across steps and keeps ``.numpy()`` legal on the results.
    with torch.no_grad():
        for _ in range(n_steps):
            node_feats, edge_feats = prep(graph, current_frame)
            # All-zero placeholders for the model's last two inputs.
            dummy_relation = torch.zeros(
                edge_feats.shape[0], 1, dtype=torch.float32, device=device
            )
            dummy_global = torch.zeros(
                node_feats.shape[0], 1, dtype=torch.float32, device=device
            )
            v_pred, _ = model(
                graph,
                node_feats[:, 3:5].float(),
                edge_feats.float(),
                dummy_global,
                dummy_relation,
            )
            # Advance columns 1-2 by the scaled prediction, then store the
            # prediction itself in columns 3-4 for the next step.
            current_frame[:, [1, 2]] += v_pred * dt
            current_frame[:, 3:5] = v_pred
            pos_buffer.append(current_frame[:, [1, 2]].cpu().numpy())
    pos_buffer = np.vstack(pos_buffer).reshape(n_steps, n_object, -1)
    make_video(pos_buffer, "video_model.mp4")


if __name__ == "__main__":
    # Script entry point: parse options, build datasets and models, train
    # while reporting validation/test loss every epoch, and optionally render
    # rollout videos at the end.
    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        "--lr", type=float, default=0.001, help="learning rate"
    )
    argparser.add_argument(
        "--epochs", type=int, default=40000, help="Number of epochs in training"
    )
    argparser.add_argument(
        "--lambda_reg", type=float, default=0.001, help="regularization weight"
    )
    argparser.add_argument(
        "--gpu", type=int, default=-1, help="gpu device code, -1 means cpu"
    )
    argparser.add_argument(
        "--batch_size", type=int, default=100, help="size of each mini batch"
    )
    argparser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="number of workers for dataloading",
    )
    argparser.add_argument(
        "--visualize",
        action="store_true",
        default=False,
        help="Whether enable trajectory rollout mode for visualization",
    )
    args = argparser.parse_args()

    # Select Device to be CPU or GPU
    if args.gpu != -1:
        device = torch.device("cuda:{}".format(args.gpu))
    else:
        device = torch.device("cpu")

    train_data = MultiBodyTrainDataset()
    valid_data = MultiBodyValidDataset()
    test_data = MultiBodyTestDataset()
    collator = MultiBodyGraphCollator(train_data.n_particles)

    train_dataloader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collator,
        num_workers=args.num_workers,
    )
    # Evaluation loaders are not shuffled: the order does not matter for a
    # mean loss, and a fixed order keeps the reported numbers reproducible.
    valid_dataloader = DataLoader(
        valid_data,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collator,
        num_workers=args.num_workers,
    )
    test_full_dataloader = DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collator,
        num_workers=args.num_workers,
    )

    node_feats = 5
    # Dataset statistics handed to the model layers.
    stat = {
        "median": torch.from_numpy(train_data.stat_median).to(device),
        "max": torch.from_numpy(train_data.stat_max).to(device),
        "min": torch.from_numpy(train_data.stat_min).to(device),
    }
    print(
        "Weight: ",
        train_data.stat_median[0],
        train_data.stat_max[0],
        train_data.stat_min[0],
    )
    print(
        "Position: ",
        train_data.stat_median[[1, 2]],
        train_data.stat_max[[1, 2]],
        train_data.stat_min[[1, 2]],
    )
    print(
        "Velocity: ",
        train_data.stat_median[[3, 4]],
        train_data.stat_max[[3, 4]],
        train_data.stat_min[[3, 4]],
    )

    prepare_layer = PrepareLayer(node_feats, stat).to(device)
    interaction_net = InteractionNet(node_feats, stat).to(device)
    print(interaction_net)
    optimizer = torch.optim.Adam(interaction_net.parameters(), lr=args.lr)

    loss_fn = torch.nn.MSELoss()
    reg_fn = torch.nn.MSELoss(reduction="sum")
    try:
        for e in range(args.epochs):
            last_t = time.time()
            loss = train(
                optimizer,
                loss_fn,
                reg_fn,
                interaction_net,
                prepare_layer,
                train_dataloader,
                args.lambda_reg,
                device,
            )
            print("Epoch time: ", time.time() - last_t)
            # Evaluate after every epoch.
            valid_loss = eval(
                loss_fn,
                interaction_net,
                prepare_layer,
                valid_dataloader,
                device,
            )
            test_full_loss = eval(
                loss_fn,
                interaction_net,
                prepare_layer,
                test_full_dataloader,
                device,
            )
            print(
                "Epoch: {}.Loss: Valid: {} Full: {}".format(
                    e, valid_loss, test_full_loss
                )
            )
    except (Exception, KeyboardInterrupt):
        # Best effort: a failure or a manual Ctrl+C ends training but still
        # falls through to the optional visualization below. SystemExit and
        # other BaseExceptions are deliberately not swallowed.
        traceback.print_exc()
    finally:
        if args.visualize:
            eval_rollout(
                interaction_net,
                prepare_layer,
                test_data.first_frame,
                test_data.n_particles,
                device,
            )
            # Ground-truth video over the same 100 steps as the rollout.
            make_video(test_data.test_traj[:100, :, [1, 2]], "video_truth.mp4")