import argparse
import time
import traceback

import networkx as nx
import numpy as np
import torch
from dataloader import (
    MultiBodyGraphCollator,
    MultiBodyTestDataset,
    MultiBodyTrainDataset,
    MultiBodyValidDataset,
)
from models import MLP, InteractionNet, PrepareLayer
from torch.utils.data import DataLoader
from utils import make_video

import dgl


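# One training epoch: fit the predicted node velocities to the labels and
# L2-regularize both the edge-level output and the network parameters.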
def train(
    optimizer, loss_fn, reg_fn, model, prep, dataloader, lambda_reg, device
):
    total_loss = 0
    model.train()
    for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader):
        graph_batch = graph_batch.to(device)
        data_batch = data_batch.to(device)
        label_batch = label_batch.to(device)
        optimizer.zero_grad()
        node_feat, edge_feat = prep(graph_batch, data_batch)
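        # The model also expects per-edge and per-graph "global" inputs;
        # zero placeholders are passed since this setup does not use them.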
        dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device)
        dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device)
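        # Only the velocity columns (3:5) of the node features are fed to the
        # model, which returns predicted velocities and an edge-level output.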
        v_pred, out_e = model(
            graph_batch,
            node_feat[:, 3:5].float(),
            edge_feat.float(),
            dummy_global,
            dummy_relation,
        )
        loss = loss_fn(v_pred, label_batch)
        total_loss += float(loss)
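        # Penalize the edge-level output toward zero (squared-L2 on interactions)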
        zero_target = torch.zeros_like(out_e)
        loss = loss + lambda_reg * reg_fn(out_e, zero_target)
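        # L2 regularization over all model parameters, weighted by lambda_reg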
        reg_loss = 0
        for param in model.parameters():
            reg_loss = reg_loss + lambda_reg * reg_fn(
                param, torch.zeros_like(param).float().to(device)
            )
        loss = loss + reg_loss
        loss.backward()
        optimizer.step()
    return total_loss / (i + 1)


# One-step evaluation: average single-step prediction loss over a dataloader
@torch.no_grad()  # no gradients are needed for evaluation
def evaluate(loss_fn, model, prep, dataloader, device):
    total_loss = 0
    model.eval()
    for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader):
        graph_batch = graph_batch.to(device)
        data_batch = data_batch.to(device)
        label_batch = label_batch.to(device)
        node_feat, edge_feat = prep(graph_batch, data_batch)
        dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device)
        dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device)
        v_pred, _ = model(
            graph_batch,
            node_feat[:, 3:5].float(),
            edge_feat.float(),
            dummy_global,
            dummy_relation,
        )
        loss = loss_fn(v_pred, label_batch)
        total_loss += float(loss)
    return total_loss / (i + 1)


# Rollout evaluation based on the initial state
# TODO: need to integrate
@torch.no_grad()  # disable autograd so the in-place state updates and .numpy() calls below work
def eval_rollout(model, prep, initial_frame, n_object, device):
    # Work on a copy so the caller's initial frame is not modified in place
    current_frame = initial_frame.to(device).clone()
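    # Fully connected interaction graph over all objects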
    base_graph = nx.complete_graph(n_object)
    graph = dgl.from_networkx(base_graph).to(device)
    pos_buffer = []
    model.eval()
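    # Roll forward for 100 steps: predict velocities, then integrate positions
    # with a hard-coded timestep of 1e-3 and buffer them for the video.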
    for step in range(100):
        node_feats, edge_feats = prep(graph, current_frame)
        dummy_relation = torch.zeros(edge_feats.shape[0], 1).float().to(device)
        dummy_global = torch.zeros(node_feats.shape[0], 1).float().to(device)
        v_pred, _ = model(
            graph,
            node_feats[:, 3:5].float(),
            edge_feats.float(),
            dummy_global,
            dummy_relation,
        )
        current_frame[:, [1, 2]] += v_pred * 0.001
        current_frame[:, 3:5] = v_pred
        pos_buffer.append(current_frame[:, [1, 2]].cpu().numpy())
    pos_buffer = np.vstack(pos_buffer).reshape(100, n_object, -1)
    make_video(pos_buffer, "video_model.mp4")


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        "--lr", type=float, default=0.001, help="learning rate"
    )
    argparser.add_argument(
        "--epochs", type=int, default=40000, help="Number of epochs in training"
    )
    argparser.add_argument(
        "--lambda_reg", type=float, default=0.001, help="regularization weight"
    )
    argparser.add_argument(
        "--gpu", type=int, default=-1, help="GPU device id; -1 means CPU"
    )
    argparser.add_argument(
        "--batch_size", type=int, default=100, help="size of each mini-batch"
    )
    argparser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="number of workers for data loading",
    )
    argparser.add_argument(
        "--visualize",
        action="store_true",
        default=False,
        help="whether to enable trajectory rollout mode for visualization",
    )
    args = argparser.parse_args()

    # Select the device: GPU if requested, otherwise CPU
    if args.gpu != -1:
        device = torch.device("cuda:{}".format(args.gpu))
    else:
        device = torch.device("cpu")

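    # Datasets share one graph collator sized to the training set's particle count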
    train_data = MultiBodyTrainDataset()
    valid_data = MultiBodyValidDataset()
    test_data = MultiBodyTestDataset()
    collator = MultiBodyGraphCollator(train_data.n_particles)

    train_dataloader = DataLoader(
        train_data,
        args.batch_size,
        shuffle=True,
        collate_fn=collator,
        num_workers=args.num_workers,
    )
    valid_dataloader = DataLoader(
        valid_data,
        args.batch_size,
        shuffle=True,
        collate_fn=collator,
        num_workers=args.num_workers,
    )
    test_full_dataloader = DataLoader(
        test_data,
        args.batch_size,
        shuffle=True,
        collate_fn=collator,
        num_workers=args.num_workers,
    )

    node_feats = 5  # per-node state: [weight, x, y, vx, vy]
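
    # Normalization statistics (median / max / min) from the training set,
    # shared by the preparation layer and the interaction network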
    stat = {
        "median": torch.from_numpy(train_data.stat_median).to(device),
        "max": torch.from_numpy(train_data.stat_max).to(device),
        "min": torch.from_numpy(train_data.stat_min).to(device),
    }
    print(
        "Weight: ",
        train_data.stat_median[0],
        train_data.stat_max[0],
        train_data.stat_min[0],
    )
    print(
        "Position: ",
        train_data.stat_median[[1, 2]],
        train_data.stat_max[[1, 2]],
        train_data.stat_min[[1, 2]],
    )
    print(
        "Velocity: ",
        train_data.stat_median[[3, 4]],
        train_data.stat_max[[3, 4]],
        train_data.stat_min[[3, 4]],
    )

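    # PrepareLayer builds node/edge features from raw states; InteractionNet
    # predicts per-node velocities from them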
    prepare_layer = PrepareLayer(node_feats, stat).to(device)
    interaction_net = InteractionNet(node_feats, stat).to(device)
    print(interaction_net)
    optimizer = torch.optim.Adam(interaction_net.parameters(), lr=args.lr)
    state_dict = interaction_net.state_dict()

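    # MSE for the prediction loss; a summed MSE against zeros acts as the
    # squared-L2 regularizer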
    loss_fn = torch.nn.MSELoss()
    reg_fn = torch.nn.MSELoss(reduction="sum")
    try:
        for e in range(args.epochs):
            last_t = time.time()
            loss = train(
                optimizer,
                loss_fn,
                reg_fn,
                interaction_net,
                prepare_layer,
                train_dataloader,
                args.lambda_reg,
                device,
            )
            print("Epoch time: ", time.time() - last_t)
            if e % 1 == 0:  # evaluate every epoch (raise the modulus to evaluate less often)
                valid_loss = evaluate(
                    loss_fn,
                    interaction_net,
                    prepare_layer,
                    valid_dataloader,
                    device,
                )
                test_full_loss = evaluate(
                    loss_fn,
                    interaction_net,
                    prepare_layer,
                    test_full_dataloader,
                    device,
                )
                print(
                    "Epoch: {}. Loss: Valid: {} Full: {}".format(
                        e, valid_loss, test_full_loss
                    )
                )
    except (Exception, KeyboardInterrupt):  # log any failure or Ctrl-C and continue to the finally block
        traceback.print_exc()
    finally:
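        # Optionally roll the trained model out from the test set's first frame
        # and render the ground-truth trajectory for comparison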
        if args.visualize:
            eval_rollout(
                interaction_net,
                prepare_layer,
                test_data.first_frame,
                test_data.n_particles,
                device,
            )
            make_video(test_data.test_traj[:100, :, [1, 2]], "video_truth.mp4")