train_sfno.py 17.7 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
Boris Bonev's avatar
Boris Bonev committed
5
#
Boris Bonev's avatar
Boris Bonev committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import time

from tqdm import tqdm
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

47
from torch_harmonics.examples.sfno import PdeDataset
48
from torch_harmonics import RealSHT
49

Boris Bonev's avatar
Boris Bonev committed
50
51
# wandb logging
import wandb
52

Boris Bonev's avatar
Boris Bonev committed
53
54
# Authenticate with Weights & Biases so training runs can log metrics.
# NOTE(review): this runs at import time and will prompt or fail without
# credentials — consider moving it into main() behind the `train` flag.
wandb.login()

55

56
def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
    """Squared-L2 loss between prediction and target, integrated over the sphere.

    The pointwise squared error is integrated over the grid with the solver's
    quadrature and summed over the trailing (channel) dimension. With
    ``relative=True`` the result is normalized by the integrated target energy;
    with ``squared=False`` the square root is taken before the batch mean.
    """
    sq_err = (prd - tar) ** 2
    batch_loss = solver.integrate_grid(sq_err, dimensionless=True).sum(dim=-1)

    if relative:
        tar_energy = solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)
        batch_loss = batch_loss / tar_energy

    if not squared:
        batch_loss = torch.sqrt(batch_loss)

    return batch_loss.mean()

67

68
def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=True):
    """L2 loss evaluated in spectral space via the spherical harmonic transform.

    Uses Parseval's relation on the SHT coefficients: m=0 modes are counted
    once, m>0 modes twice (they stand in for both +m and -m). With
    ``relative=True`` the error norm is divided by the target norm; with
    ``squared=False`` the square root is taken before the batch mean.
    """

    def sq_norm(field):
        # |coefficient|^2 with m>0 double-counted, summed over l and channels
        c = torch.view_as_real(solver.sht(field))
        power = c[..., 0] ** 2 + c[..., 1] ** 2
        per_l = power[..., :, 0] + 2 * torch.sum(power[..., :, 1:], dim=-1)
        return torch.sum(per_l, dim=(-1, -2))

    loss = sq_norm(prd - tar)

    if relative:
        loss = loss / sq_norm(tar)

    if not squared:
        loss = torch.sqrt(loss)

    return loss.mean()

88

89
def spectral_loss_sphere(solver, prd, tar, relative=False, squared=True):
    """Gradient-weighted spectral loss on the sphere.

    Each spherical harmonic mode of degree l is weighted by l*(l+1) (the
    eigenvalue of the negative Laplace-Beltrami operator), so the loss
    penalizes the gradient of the error rather than the error itself.
    """
    # l*(l+1) weighting factors, broadcast over (batch, channel, l, m)
    lmax = solver.sht.lmax
    degrees = torch.arange(lmax).float()
    weights = (degrees * (degrees + 1)).reshape(1, 1, -1, 1).to(prd.device)

    def weighted_norm2(field):
        # weighted |coefficient|^2 with m>0 double-counted, summed over l and channels
        c = torch.view_as_real(solver.sht(field))
        power = weights * (c[..., 0] ** 2 + c[..., 1] ** 2)
        per_l = power[..., :, 0] + 2 * torch.sum(power[..., :, 1:], dim=-1)
        return torch.sum(per_l, dim=(-1, -2))

    loss = weighted_norm2(prd - tar)

    if relative:
        loss = loss / weighted_norm2(tar)

    if not squared:
        loss = torch.sqrt(loss)

    return loss.mean()

116

117
def h1loss_sphere(solver, prd, tar, relative=False, squared=True):
    """H1-type spectral loss: L2 norm plus the l*(l+1)-weighted seminorm.

    Strictly speaking this is not exactly the H1 norm: the two terms are
    combined after the per-sample reduction (and, for ``squared=False``, each
    is rooted separately) rather than under a single square root.

    Raises:
        NotImplementedError: if ``relative=True``.
    """
    # l*(l+1) gradient weighting factors, broadcast over (batch, channel, l, m)
    lmax = solver.sht.lmax
    degrees = torch.arange(lmax).float()
    weights = (degrees * (degrees + 1)).reshape(1, 1, -1, 1).to(prd.device)

    # squared coefficient magnitudes of the error field
    c = torch.view_as_real(solver.sht(prd - tar))
    power = c[..., 0] ** 2 + c[..., 1] ** 2

    def reduce_modes(p):
        # m=0 once, m>0 twice; then sum over l and channels
        per_l = p[..., :, 0] + 2 * torch.sum(p[..., :, 1:], dim=-1)
        return torch.sum(per_l, dim=(-1, -2))

    h1_term = reduce_modes(weights * power)
    l2_term = reduce_modes(power)

    if not squared:
        loss = torch.sqrt(h1_term) + torch.sqrt(l2_term)
    else:
        loss = h1_term + l2_term

    if relative:
        raise NotImplementedError("Relative H1 loss not implemented")

    return loss.mean()

145

Boris Bonev's avatar
Boris Bonev committed
146
147
def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
    """L2 loss weighted per channel by the fluctuation of the target around the input.

    Channels where the target moves further away from the input receive a
    proportionally larger weight, emphasizing dynamically active fields.
    """
    # per-channel fluctuation energy, normalized so the weights sum to one
    fluct = solver.integrate_grid((tar - inp) ** 2, dimensionless=True, polar_opt=polar_opt)
    weight = fluct / torch.sum(fluct, dim=-1, keepdim=True)

    err_energy = solver.integrate_grid((prd - tar) ** 2, dimensionless=True, polar_opt=polar_opt)
    loss = weight * err_energy

    if relative:
        tar_energy = solver.integrate_grid(tar**2, dimensionless=True, polar_opt=polar_opt)
        loss = loss / (weight * tar_energy)

    return torch.mean(loss)

158

159
# rolls out the FNO and compares to the classical solver
def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10, nskip=1, plot_channel=0, nics=50):
    """Autoregressively roll out the model and compare it to the classical solver.

    For ``nics`` random initial conditions, both the ML model and the spectral
    solver are advanced ``autoreg_steps`` steps. Relative L2 losses and
    wall-clock times are recorded per initial condition; for the last one,
    snapshots are plotted every ``nskip`` steps (``nskip=0`` disables plotting),
    and an averaged power-spectrum plot is written next to ``path_root``.

    Returns:
        Tuple of numpy arrays ``(losses, fno_times, nwp_times)``, each of length
        ``nics``.
    """
    model.eval()

    losses = np.zeros(nics)
    fno_times = np.zeros(nics)
    nwp_times = np.zeros(nics)

    # accumulation buffers for the power spectrum
    prd_mean_coeffs = []
    ref_mean_coeffs = []

    for iic in range(nics):
        ic = dataset.solver.random_initial_condition(mach=0.2)
        inp_mean = dataset.inp_mean
        inp_var = dataset.inp_var

        # normalize the initial condition the same way the training data was normalized
        prd = (dataset.solver.spec2grid(ic) - inp_mean) / torch.sqrt(inp_var)
        prd = prd.unsqueeze(0)
        uspec = ic.clone()

        # ML model rollout
        start_time = time.time()
        for i in range(1, autoreg_steps + 1):
            # evaluate the ML model
            prd = model(prd)

            # plot only for the last initial condition; test nskip > 0 BEFORE
            # i % nskip to avoid ZeroDivisionError when plotting is disabled
            if iic == nics - 1 and nskip > 0 and i % nskip == 0:
                fig = plt.figure(figsize=(7.5, 6))
                dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4)
                plt.savefig(path_root + "_pred_" + str(i // nskip) + ".png")
                plt.close()

        fno_times[iic] = time.time() - start_time

        # classical solver rollout
        start_time = time.time()
        for i in range(1, autoreg_steps + 1):
            # advance classical model
            uspec = dataset.solver.timestep(uspec, nsteps)
            ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)

            # bugfix: guard nskip > 0 before the modulo (original order raised
            # ZeroDivisionError for nskip == 0)
            if iic == nics - 1 and nskip > 0 and i % nskip == 0:
                fig = plt.figure(figsize=(7.5, 6))
                dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4)
                plt.savefig(path_root + "_truth_" + str(i // nskip) + ".png")
                plt.close()

        nwp_times[iic] = time.time() - start_time

        # compute power spectrum and add it to the buffers
        prd_coeffs = dataset.solver.sht(prd[0, plot_channel])
        ref_coeffs = dataset.solver.sht(ref[plot_channel])
        prd_mean_coeffs.append(prd_coeffs)
        ref_mean_coeffs.append(ref_coeffs)

        # undo the normalization and compare in physical units
        ref = dataset.solver.spec2grid(uspec)
        prd = prd * torch.sqrt(inp_var) + inp_mean
        losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()

    # compute the averaged powerspectra of prediction and reference;
    # m=0 modes counted once, m>0 modes twice (they stand in for +m and -m)
    prd_mean_coeffs = torch.stack(prd_mean_coeffs).abs().pow(2).mean(dim=0)
    ref_mean_coeffs = torch.stack(ref_mean_coeffs).abs().pow(2).mean(dim=0)
    prd_mean_coeffs[..., 1:] *= 2.0
    ref_mean_coeffs[..., 1:] *= 2.0
    prd_mean_ps = prd_mean_coeffs.sum(dim=-1).detach().cpu()
    ref_mean_ps = ref_mean_coeffs.sum(dim=-1).detach().cpu()

    # plot the averaged powerspectrum
    fig = plt.figure(figsize=(7.5, 6))
    plt.loglog(prd_mean_ps, label="prediction")
    plt.loglog(ref_mean_ps, label="reference")
    plt.xlabel("$l$")
    plt.ylabel("powerspectrum")
    plt.legend()
    plt.savefig(path_root + "_powerspectrum.png")
    plt.close()

    return losses, fno_times, nwp_times
Boris Bonev's avatar
Boris Bonev committed
244

245

246
247
248
249
250
251
# convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1):
    """
    Helper routine intended for debugging purposes.

    Dumps the model's parameters and their gradients to
    ``weights_and_grads/weights_and_grads_step{iters:03d}.tar`` next to this file.
    """
    root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
    # bugfix: make sure the output directory exists, otherwise torch.save fails
    os.makedirs(root_path, exist_ok=True)

    weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
    print(weights_and_grads_fname)

    weights_dict = {k: v for k, v in model.named_parameters()}
    # .grad is None for parameters that have not received a gradient yet
    grad_dict = {k: v.grad for k, v in model.named_parameters()}

    store_dict = {"iteration": iters, "grads": grad_dict, "weights": weights_dict}
    torch.save(store_dict, weights_and_grads_fname)
Boris Bonev's avatar
Boris Bonev committed
261

262

263
# training function
def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=20, nfuture=0, num_examples=256, num_valid=8, loss_fn="l2", enable_amp=False, log_grads=0):
    """Train ``model`` for ``nepochs`` epochs and validate after each epoch.

    Each epoch draws ``num_examples`` random training samples and ``num_valid``
    validation samples from the same dataloader (the dataset is reconfigured in
    place between the two phases). Training optionally unrolls ``nfuture``
    additional autoregressive steps per sample and runs under AMP when
    ``enable_amp`` is set, using the supplied ``gscaler`` for loss scaling.

    Args:
        model: network to train; must map an input batch to a same-shaped output.
        dataloader: loader whose dataset exposes set_initial_condition,
            set_num_examples and a ``solver`` (used by the loss functions).
        optimizer: optimizer stepping the model parameters.
        gscaler: torch GradScaler used for AMP-safe backward/step.
        scheduler: optional LR scheduler stepped with the validation loss
            (ReduceLROnPlateau-style interface).
        nepochs: number of epochs.
        nfuture: extra autoregressive model applications per training sample.
        num_examples: training samples per epoch.
        num_valid: validation samples per epoch.
        loss_fn: one of "l2", "spectral l2", "h1", "spectral", "fluct".
        enable_amp: enable torch.autocast mixed precision.
        log_grads: if non-zero, dump weights/grads every ``log_grads`` iterations.

    Returns:
        The last epoch's mean relative validation loss.
    """

    train_start = time.time()

    # count iterations
    iters = 0

    for epoch in range(nepochs):

        # time each epoch
        epoch_start = time.time()

        # reconfigure the shared dataset for the training phase
        dataloader.dataset.set_initial_condition("random")
        dataloader.dataset.set_num_examples(num_examples)

        # get the solver for its convenience functions
        solver = dataloader.dataset.solver

        # do the training
        acc_loss = 0
        model.train()

        for inp, tar in dataloader:

            # autocast is a no-op when enable_amp is False
            with torch.autocast(device_type="cuda", enabled=enable_amp):
                prd = model(inp)
                # optional multi-step unrolling: feed predictions back into the model
                for _ in range(nfuture):
                    prd = model(prd)

                if loss_fn == "l2":
                    loss = l2loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "spectral l2":
                    loss = spectral_l2loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "h1":
                    loss = h1loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "spectral":
                    loss = spectral_loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "fluct":
                    loss = fluct_l2loss_sphere(solver, prd, tar, inp, relative=True)
                else:
                    raise NotImplementedError(f"Unknown loss function {loss_fn}")

            # accumulate the loss weighted by batch size for a correct dataset mean
            acc_loss += loss.item() * inp.size(0)

            optimizer.zero_grad(set_to_none=True)
            # scale the loss for AMP; backward must happen before gscaler.step
            gscaler.scale(loss).backward()

            # optional gradient/weight dump for debugging
            if log_grads and iters % log_grads == 0:
                log_weights_and_grads(model, iters=iters)

            gscaler.step(optimizer)
            gscaler.update()

            iters += 1

        acc_loss = acc_loss / len(dataloader.dataset)

        # reconfigure the same dataset for the (smaller) validation phase
        dataloader.dataset.set_initial_condition("random")
        dataloader.dataset.set_num_examples(num_valid)

        # perform validation
        valid_loss = 0
        model.eval()
        with torch.no_grad():
            for inp, tar in dataloader:
                prd = model(inp)
                for _ in range(nfuture):
                    prd = model(prd)
                # validation always uses the relative L2 loss, regardless of loss_fn
                loss = l2loss_sphere(solver, prd, tar, relative=True)

                valid_loss += loss.item() * inp.size(0)

        valid_loss = valid_loss / len(dataloader.dataset)

        # scheduler steps on the validation loss (ReduceLROnPlateau-style)
        if scheduler is not None:
            scheduler.step(valid_loss)

        epoch_time = time.time() - epoch_start

        print(f"--------------------------------------------------------------------------------")
        print(f"Epoch {epoch} summary:")
        print(f"time taken: {epoch_time}")
        print(f"accumulated training loss: {acc_loss}")
        print(f"relative validation loss: {valid_loss}")

        # log to wandb only when a run is active
        if wandb.run is not None:
            current_lr = optimizer.param_groups[0]["lr"]
            wandb.log({"loss": acc_loss, "validation loss": valid_loss, "learning rate": current_lr})

    train_time = time.time() - train_start

    print(f"--------------------------------------------------------------------------------")
    print(f"done. Training took {train_time}.")
    return valid_loss
Boris Bonev's avatar
Boris Bonev committed
359

360

361
def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
    """Entry point: build the SWE dataset, then train and evaluate each model.

    Args:
        train: run the training loop (otherwise only inference is performed).
        load_checkpoint: load model weights from ``checkpoints/<model_name>``.
        enable_amp: enable automatic mixed precision during training.
        log_grads: passed through to train_model; dump weights/grads every
            ``log_grads`` iterations when non-zero.

    Side effects: writes checkpoints, figures and a metrics pickle under the
    directory containing this file, and logs to wandb when training.
    """

    # set seed for reproducibility
    torch.manual_seed(333)
    torch.cuda.manual_seed(333)

    # set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device.index)

    # 1 hour prediction steps, advanced by the classical solver in 150 s substeps
    dt = 1 * 3600
    dt_solver = 150
    nsteps = dt // dt_solver
    dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True)
    # There is still an issue with parallel dataloading. Do NOT use it at the moment
    # dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)

    nlat = dataset.nlat
    nlon = dataset.nlon

    def count_parameters(model):
        # number of trainable parameters
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    # prepare dicts containing models and corresponding metrics
    models = {}
    metrics = {}

    from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO
    from torch_harmonics.examples.sfno import LocalSphericalNeuralOperatorNet as LSNO

    # models["sfno_sc2_layers6_e32"] = partial(
    #     SFNO,
    #     spectral_transform="sht",
    #     img_size=(nlat, nlon),
    #     grid="equiangular",
    #     num_layers=6,
    #     scale_factor=1,
    #     embed_dim=32,
    #     operator_type="driscoll-healy",
    #     activation_function="gelu",
    #     big_skip=True,
    #     pos_embed=False,
    #     use_mlp=True,
    #     normalization_layer="none",
    # )

    models["lsno_sc2_layers6_e32"] = partial(
        LSNO,
        spectral_transform="sht",
        img_size=(nlat, nlon),
        grid="equiangular",
        num_layers=6,
        scale_factor=1,
        embed_dim=32,
        operator_type="driscoll-healy",
        activation_function="gelu",
        big_skip=True,
        pos_embed=False,
        use_mlp=True,
        normalization_layer="none",
    )

    # iterate over models and train each model
    root_path = os.path.dirname(__file__)
    # bugfix: ensure the output directories exist before anything is written to them
    for subdir in ("checkpoints", "figures", "output_data"):
        os.makedirs(os.path.join(root_path, subdir), exist_ok=True)

    for model_name, model_handle in models.items():

        model = model_handle().to(device)

        print(model)

        metrics[model_name] = {}

        num_params = count_parameters(model)
        print(f"number of trainable params: {num_params}")
        metrics[model_name]["num_params"] = num_params

        if load_checkpoint:
            model.load_state_dict(torch.load(os.path.join(root_path, "checkpoints/" + model_name), weights_only=True))

        # run the training
        if train:
            run = wandb.init(project="sfno ablations spherical swe", group=model_name, name=model_name + "_" + str(time.time()), config=model_handle.keywords)

            # optimizer:
            optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
            gscaler = torch.GradScaler("cuda", enabled=enable_amp)

            start_time = time.time()

            print(f"Training {model_name}, single step")
            train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)

            # # multistep training
            # print(f'Training {model_name}, two step')
            # optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
            # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
            # gscaler = torch.GradScaler(enabled=enable_amp)
            # dataloader.dataset.nsteps = 2 * dt//dt_solver
            # train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=5, nfuture=1, enable_amp=enable_amp)
            # dataloader.dataset.nsteps = 1 * dt//dt_solver

            training_time = time.time() - start_time

            run.finish()

            torch.save(model.state_dict(), os.path.join(root_path, "checkpoints/" + model_name))

        # re-seed so inference is reproducible regardless of whether training ran
        torch.manual_seed(333)
        torch.cuda.manual_seed(333)

        with torch.inference_mode():
            losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(root_path, "figures/" + model_name), nsteps=nsteps, autoreg_steps=30)
            metrics[model_name]["loss_mean"] = np.mean(losses)
            metrics[model_name]["loss_std"] = np.std(losses)
            metrics[model_name]["fno_time_mean"] = np.mean(fno_times)
            metrics[model_name]["fno_time_std"] = np.std(fno_times)
            metrics[model_name]["nwp_time_mean"] = np.mean(nwp_times)
            metrics[model_name]["nwp_time_std"] = np.std(nwp_times)
            if train:
                metrics[model_name]["training_time"] = training_time

    df = pd.DataFrame(metrics)
    df.to_pickle(os.path.join(root_path, "output_data/metrics.pkl"))

Boris Bonev's avatar
Boris Bonev committed
490
491
492
493

if __name__ == "__main__":
    import torch.multiprocessing as mp

    # NOTE(review): forkserver is presumably used to avoid fork-after-CUDA-init
    # problems with dataloader workers — confirm before changing
    mp.set_start_method("forkserver", force=True)

    # main(train=False, load_checkpoint=True, enable_amp=False, log_grads=0)
    main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0)