# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import time
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from torch_harmonics import RealSHT
from torch_harmonics.examples import PdeDataset

# wandb logging
import wandb


54
def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
    """Geometric L2 loss on the sphere.

    Integrates the squared pointwise error over the sphere via the solver's
    quadrature, sums over the channel dimension and averages over the batch.

    Args:
        solver: object exposing ``integrate_grid`` for spherical quadrature
        prd: predicted field
        tar: target field
        relative: if True, normalize by the integrated squared target
        squared: if False, take the square root before averaging
    Returns:
        scalar tensor with the (batch-averaged) loss
    """
    num = solver.integrate_grid((prd - tar) ** 2, dimensionless=True).sum(dim=-1)
    if relative:
        den = solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)
        num = num / den
    if not squared:
        num = torch.sqrt(num)
    return num.mean()

65

66
def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=True):
    """Spectral L2 loss computed from spherical harmonic coefficients.

    The squared norm is accumulated over all (l, m) modes, with m > 0 modes
    counted twice to account for conjugate symmetry of real fields, then
    summed over channels and averaged over the batch.

    Args:
        solver: object whose ``sht`` maps a grid field to complex coefficients
        prd: predicted field
        tar: target field
        relative: if True, normalize by the spectral norm of the target
        squared: if False, take the square root before averaging
    Returns:
        scalar tensor with the (batch-averaged) loss
    """

    def _sq_norm(field):
        # squared coefficient magnitudes; m > 0 columns counted twice
        c = torch.view_as_real(solver.sht(field))
        power = c[..., 0] ** 2 + c[..., 1] ** 2
        per_l = power[..., :, 0] + 2 * torch.sum(power[..., :, 1:], dim=-1)
        return torch.sum(per_l, dim=(-1, -2))

    loss = _sq_norm(prd - tar)
    if relative:
        loss = loss / _sq_norm(tar)
    if not squared:
        loss = torch.sqrt(loss)
    return loss.mean()

86

87
def spectral_loss_sphere(solver, prd, tar, relative=False, squared=True):
    """Laplacian-weighted spectral loss on the sphere.

    Like ``spectral_l2loss_sphere``, but each coefficient is weighted by
    l*(l+1), which emphasizes errors in the higher spherical harmonic degrees.

    Args:
        solver: object whose ``sht`` maps a grid field to complex coefficients
            and exposes ``sht.lmax``
        prd: predicted field
        tar: target field
        relative: if True, normalize by the weighted spectral norm of the target
        squared: if False, take the square root before averaging
    Returns:
        scalar tensor with the (batch-averaged) loss
    """
    # l*(l+1) weights, broadcast over (batch, channel, l, m)
    lmax = solver.sht.lmax
    degrees = torch.arange(lmax).float()
    weights = (degrees * (degrees + 1)).reshape(1, 1, -1, 1).to(prd.device)

    def _weighted_norm2(field):
        # weighted squared coefficient magnitudes; m > 0 columns counted twice
        c = torch.view_as_real(solver.sht(field))
        power = c[..., 0] ** 2 + c[..., 1] ** 2
        power = weights * power
        per_l = power[..., :, 0] + 2 * torch.sum(power[..., :, 1:], dim=-1)
        return torch.sum(per_l, dim=(-1, -2))

    loss = _weighted_norm2(prd - tar)
    if relative:
        loss = loss / _weighted_norm2(tar)
    if not squared:
        loss = torch.sqrt(loss)
    return loss.mean()

114

115
def h1loss_sphere(solver, prd, tar, relative=False, squared=True):
    """Approximate H1 loss on the sphere, computed in spectral space.

    Combines the plain spectral L2 norm of the error with an l*(l+1)-weighted
    (gradient) norm. As the original comment noted, this is not exactly the
    H1 norm, but the sum of the L2 and the seminorm contributions.

    Args:
        solver: object whose ``sht`` maps a grid field to complex coefficients
            and exposes ``sht.lmax``
        prd: predicted field
        tar: target field
        relative: not supported; raises immediately
        squared: if False, the square roots of both contributions are summed

    Returns:
        scalar tensor with the (batch-averaged) loss

    Raises:
        NotImplementedError: if ``relative`` is True.
    """
    # fail fast on the unsupported option instead of computing both norms first
    if relative:
        raise NotImplementedError("Relative H1 loss not implemented")

    # gradient weighting factors l*(l+1) (eigenvalues of the negative
    # Laplace-Beltrami operator on the sphere)
    lmax = solver.sht.lmax
    ls = torch.arange(lmax).float()
    spectral_weights = (ls * (ls + 1)).reshape(1, 1, -1, 1).to(prd.device)

    # squared magnitudes of the SHT coefficients of the error
    coeffs = torch.view_as_real(solver.sht(prd - tar))
    coeffs = coeffs[..., 0] ** 2 + coeffs[..., 1] ** 2

    # m > 0 modes are counted twice to account for conjugate symmetry
    h1_coeffs = spectral_weights * coeffs
    h1_norm2 = h1_coeffs[..., :, 0] + 2 * torch.sum(h1_coeffs[..., :, 1:], dim=-1)
    l2_norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)

    h1_loss = torch.sum(h1_norm2, dim=(-1, -2))
    l2_loss = torch.sum(l2_norm2, dim=(-1, -2))

    # strictly speaking this is not exactly the h1 loss
    if not squared:
        loss = torch.sqrt(h1_loss) + torch.sqrt(l2_loss)
    else:
        loss = h1_loss + l2_loss

    return loss.mean()

143

Boris Bonev's avatar
Boris Bonev committed
144
145
def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
    """Fluctuation-weighted L2 loss on the sphere.

    Each channel is weighted by how strongly the target deviates from the
    input, so channels that barely change contribute less to the loss.

    Args:
        solver: object exposing ``integrate_grid`` for spherical quadrature
        prd: predicted field
        tar: target field
        inp: input field the weighting is measured against
        relative: if True, normalize by the weighted integrated squared target
        polar_opt: passed through to ``integrate_grid``
    Returns:
        scalar tensor with the mean weighted loss
    """
    # per-channel weights proportional to the integrated squared change tar - inp
    fluct = solver.integrate_grid((tar - inp) ** 2, dimensionless=True, polar_opt=polar_opt)
    weight = fluct / torch.sum(fluct, dim=-1, keepdim=True)

    err2 = solver.integrate_grid((prd - tar) ** 2, dimensionless=True, polar_opt=polar_opt)
    loss = weight * err2
    if relative:
        tar2 = solver.integrate_grid(tar**2, dimensionless=True, polar_opt=polar_opt)
        loss = loss / (weight * tar2)
    return torch.mean(loss)

156

157
# rolls out the FNO and compares to the classical solver
158
def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10, nskip=1, plot_channel=0, nics=50):
    """Roll out the model autoregressively and compare it to the classical solver.

    For ``nics`` random initial conditions both the model and the spectral
    solver are advanced ``autoreg_steps`` steps. For the last initial
    condition, snapshots of prediction and reference are plotted every
    ``nskip``-th step (``nskip=0`` disables plotting), and averaged power
    spectra are plotted for every step.

    Args:
        model: trained network advancing the normalized state by one step
        dataset: provides ``solver``, normalization stats and an ``sht``
        path_root: output directory for figures (created if missing)
        nsteps: number of solver sub-steps per model step
        autoreg_steps: number of autoregressive steps to roll out
        nskip: plot stride for the last initial condition; 0 disables plotting
        plot_channel: channel index used for plots and power spectra
        nics: number of random initial conditions

    Returns:
        Tuple of numpy arrays ``(losses, fno_times, nwp_times)`` holding the
        relative L2 loss and wall-clock times per initial condition.
    """

    model.eval()

    # make output directory
    if not os.path.isdir(path_root):
        os.makedirs(path_root, exist_ok=True)

    losses = np.zeros(nics)
    fno_times = np.zeros(nics)
    nwp_times = np.zeros(nics)

    # accumulation buffers for the power spectrum
    prd_mean_coeffs = []
    ref_mean_coeffs = []

    for iic in range(nics):
        ic = dataset.solver.random_initial_condition(mach=0.2)
        inp_mean = dataset.inp_mean
        inp_var = dataset.inp_var

        # normalize the initial condition the same way the training data is normalized
        prd = (dataset.solver.spec2grid(ic) - inp_mean) / torch.sqrt(inp_var)
        prd = prd.unsqueeze(0)
        uspec = ic.clone()

        # add IC to power spectrum series
        prd_coeffs = [dataset.sht(prd[0, plot_channel]).detach().cpu().clone()]
        ref_coeffs = [prd_coeffs[0].clone()]

        # ML model rollout
        start_time = time.time()
        for i in range(1, autoreg_steps + 1):

            # evaluate the ML model
            prd = model(prd)

            prd_coeffs.append(dataset.sht(prd[0, plot_channel]).detach().cpu().clone())

            if iic == nics - 1 and nskip > 0 and i % nskip == 0:

                # do plotting
                fig = plt.figure(figsize=(7.5, 6))
                dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
                plt.savefig(os.path.join(path_root, 'pred_' + str(i // nskip) + '.png'))
                plt.close()

        fno_times[iic] = time.time() - start_time

        # classical solver rollout
        start_time = time.time()
        for i in range(1, autoreg_steps + 1):

            # advance classical model
            uspec = dataset.solver.timestep(uspec, nsteps)
            ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
            ref_coeffs.append(dataset.sht(ref[plot_channel]).detach().cpu().clone())

            # fixed: check nskip > 0 BEFORE the modulo, mirroring the ML loop;
            # previously `i % nskip` was evaluated first and raised
            # ZeroDivisionError when plotting was disabled with nskip=0
            if iic == nics - 1 and nskip > 0 and i % nskip == 0:

                fig = plt.figure(figsize=(7.5, 6))
                dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
                plt.savefig(os.path.join(path_root, 'truth_' + str(i // nskip) + '.png'))
                plt.close()

        nwp_times[iic] = time.time() - start_time

        # add the coefficient series of this initial condition to the buffers
        prd_mean_coeffs.append(torch.stack(prd_coeffs, 0))
        ref_mean_coeffs.append(torch.stack(ref_coeffs, 0))

        # compute the relative L2 error in un-normalized units
        # ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
        ref = dataset.solver.spec2grid(uspec)
        prd = prd * torch.sqrt(inp_var) + inp_mean
        losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()

    # compute the averaged powerspectra of prediction and reference
    with torch.no_grad():
        prd_mean_coeffs = torch.stack(prd_mean_coeffs, dim=0).abs().pow(2).mean(dim=0)
        ref_mean_coeffs = torch.stack(ref_mean_coeffs, dim=0).abs().pow(2).mean(dim=0)

        # m > 0 modes are counted twice due to conjugate symmetry
        prd_mean_coeffs[..., 1:] *= 2.0
        ref_mean_coeffs[..., 1:] *= 2.0
        prd_mean_ps = prd_mean_coeffs.sum(dim=-1).contiguous()
        ref_mean_ps = ref_mean_coeffs.sum(dim=-1).contiguous()

        # split into one spectrum per autoregressive step
        prd_mean_ps = [x.squeeze() for x in list(torch.split(prd_mean_ps, 1, dim=0))]
        ref_mean_ps = [x.squeeze() for x in list(torch.split(ref_mean_ps, 1, dim=0))]

    # plot the averaged powerspectrum for each step
    for step, (pps, rps) in enumerate(zip(prd_mean_ps, ref_mean_ps)):
        fig = plt.figure(figsize=(7.5, 6))
        plt.semilogy(pps, label="prediction")
        plt.semilogy(rps, label="reference")
        plt.xlabel("$l$")
        plt.ylabel("powerspectrum")
        plt.legend()
        plt.savefig(os.path.join(path_root, f'powerspectrum_{step}.png'))
        fig.clf()
        plt.close()

    return losses, fno_times, nwp_times
Boris Bonev's avatar
Boris Bonev committed
259

260

261
262
263
264
265
# convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1):
    """
    Helper routine intended for debugging purposes.

    Dumps all named parameters and their gradients of ``model`` to
    ``<cwd>/weights_and_grads/weights_and_grads_step<iters>.tar``.

    Args:
        model: torch module whose parameters/gradients are logged
        iters: iteration counter embedded in the file name
    """
    root_path = os.path.join(os.getcwd(), "weights_and_grads")
    # fixed: the output directory was never created, so torch.save failed
    # with FileNotFoundError on the first call
    os.makedirs(root_path, exist_ok=True)

    weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
    print(weights_and_grads_fname)

    weights_dict = {k: v for k, v in model.named_parameters()}
    grad_dict = {k: v.grad for k, v in model.named_parameters()}

    store_dict = {"iteration": iters, "grads": grad_dict, "weights": weights_dict}
    torch.save(store_dict, weights_and_grads_fname)
Boris Bonev's avatar
Boris Bonev committed
276

277

278
# training function
279
def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=20, nfuture=0, num_examples=256, num_valid=8, loss_fn="l2", enable_amp=False, log_grads=0):
    """Train ``model`` on the spherical SWE dataset and validate every epoch.

    Args:
        model: network mapping an input state to the next state
        dataloader: yields (input, target) pairs; its dataset exposes
            ``set_initial_condition``, ``set_num_examples`` and ``solver``
        optimizer: torch optimizer
        gscaler: gradient scaler with ``scale``/``step``/``update`` (e.g.
            ``torch.GradScaler``) used for optional AMP training
        scheduler: optional LR scheduler stepped on the validation loss
        nepochs: number of training epochs
        nfuture: number of extra autoregressive model applications per sample
        num_examples: number of training examples per epoch
        num_valid: number of validation examples per epoch
        loss_fn: one of "l2", "spectral l2", "h1", "spectral", "fluct"
        enable_amp: enables autocast mixed precision in the forward/loss pass
        log_grads: if > 0, dump weights/grads every ``log_grads`` iterations

    Returns:
        The relative validation loss of the final epoch.
    """

    train_start = time.time()

    # count iterations
    iters = 0

    for epoch in range(nepochs):

        # time each epoch
        epoch_start = time.time()

        # draw fresh random training examples for this epoch
        dataloader.dataset.set_initial_condition("random")
        dataloader.dataset.set_num_examples(num_examples)

        # get the solver for its convenience functions
        solver = dataloader.dataset.solver

        # do the training
        acc_loss = 0
        model.train()

        for inp, tar in dataloader:

            with torch.autocast(device_type="cuda", enabled=enable_amp):

                # single step plus optional autoregressive steps
                prd = model(inp)
                for _ in range(nfuture):
                    prd = model(prd)

                if loss_fn == "l2":
                    loss = l2loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "spectral l2":
                    loss = spectral_l2loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "h1":
                    loss = h1loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "spectral":
                    loss = spectral_loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "fluct":
                    loss = fluct_l2loss_sphere(solver, prd, tar, inp, relative=True)
                else:
                    raise NotImplementedError(f"Unknown loss function {loss_fn}")

            # accumulate batch-size-weighted loss for the epoch average
            acc_loss += loss.item() * inp.size(0)

            optimizer.zero_grad(set_to_none=True)
            gscaler.scale(loss).backward()

            # optionally dump weights and gradients for debugging
            if log_grads and iters % log_grads == 0:
                log_weights_and_grads(model, iters=iters)

            gscaler.step(optimizer)
            gscaler.update()

            iters += 1

        acc_loss = acc_loss / len(dataloader.dataset)

        # switch the dataset to a (smaller) validation split
        dataloader.dataset.set_initial_condition("random")
        dataloader.dataset.set_num_examples(num_valid)

        # perform validation
        valid_loss = 0
        model.eval()

        with torch.no_grad():
            for inp, tar in dataloader:
                prd = model(inp)
                for _ in range(nfuture):
                    prd = model(prd)
                # validation always uses the relative L2 loss
                loss = l2loss_sphere(solver, prd, tar, relative=True)

                valid_loss += loss.item() * inp.size(0)

        valid_loss = valid_loss / len(dataloader.dataset)

        if scheduler is not None:
            scheduler.step(valid_loss)

        epoch_time = time.time() - epoch_start

        print(f"--------------------------------------------------------------------------------")
        print(f"Epoch {epoch} summary:")
        print(f"time taken: {epoch_time}")
        print(f"accumulated training loss: {acc_loss}")
        print(f"relative validation loss: {valid_loss}")

        # log to wandb if a run is active
        if wandb.run is not None:
            current_lr = optimizer.param_groups[0]["lr"]
            wandb.log({"loss": acc_loss, "validation loss": valid_loss, "learning rate": current_lr})

    train_time = time.time() - train_start

    print(f"--------------------------------------------------------------------------------")
    print(f"done. Training took {train_time}.")

    return valid_loss
Boris Bonev's avatar
Boris Bonev committed
374

375

376
def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
    """Train and evaluate spherical neural operator models on the SWE problem.

    Builds the PDE dataset, instantiates a set of SFNO/LSNO model
    configurations, optionally trains each one, and runs autoregressive
    inference against the classical solver. Metrics are collected into a
    pandas DataFrame and pickled.

    Args:
        train: if True, train each model before evaluation
        load_checkpoint: if True, initialize each model from its checkpoint
        enable_amp: enable automatic mixed precision during training
        log_grads: if > 0, dump weights/gradients every ``log_grads`` iters
    """

    # set seed for reproducibility
    torch.manual_seed(333)
    torch.cuda.manual_seed(333)

    # login to wandb for experiment tracking
    wandb.login()

    # number of additional autoregressive steps for optional fine-tuning
    nfuture = 0

    # set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device.index)

    # 1 hour prediction steps, integrated with a 150s solver timestep
    dt = 1 * 3600
    dt_solver = 150
    nsteps = dt // dt_solver
    grid = "legendre-gauss"
    nlat, nlon = (257, 512)
    dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(nlat, nlon), device=device, grid=grid, normalize=True)
    # SHT attached for the power-spectrum computation in autoregressive_inference
    dataset.sht = RealSHT(nlat=nlat, nlon=nlon, grid=grid).to(device=device)
    # There is still an issue with parallel dataloading. Do NOT use it at the moment
    # dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)

    nlat = dataset.nlat
    nlon = dataset.nlon

    def count_parameters(model):
        # number of trainable parameters
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    # prepare dicts containing models and corresponding metrics
    models = {}
    metrics = {}

    from torch_harmonics.examples.models import SphericalFourierNeuralOperatorNet as SFNO
    from torch_harmonics.examples.models import LocalSphericalNeuralOperatorNet as LSNO

    models["sfno_sc2_layers4_e32"] = partial(
        SFNO,
        img_size=(nlat, nlon),
        grid=grid,
        hard_thresholding_fraction=0.8,
        num_layers=4,
        scale_factor=2,
        embed_dim=32,
        activation_function="gelu",
        big_skip=True,
        pos_embed=False,
        use_mlp=True,
        normalization_layer="none",
    )

    models["lsno_sc2_layers4_e32_morlet"] = partial(
        LSNO,
        img_size=(nlat, nlon),
        grid=grid,
        num_layers=4,
        scale_factor=2,
        embed_dim=32,
        activation_function="gelu",
        big_skip=True,
        pos_embed=False,
        use_mlp=True,
        normalization_layer="none",
        kernel_shape=(2, 2),
        encoder_kernel_shape=(2, 2),
        filter_basis_type="morlet",
        upsample_sht=True,
    )

    models["lsno_sc2_layers4_e32_zernike"] = partial(
        LSNO,
        img_size=(nlat, nlon),
        grid=grid,
        num_layers=4,
        scale_factor=2,
        embed_dim=32,
        activation_function="gelu",
        big_skip=True,
        pos_embed=False,
        use_mlp=True,
        normalization_layer="none",
        kernel_shape=(4),
        encoder_kernel_shape=(4),
        filter_basis_type="zernike",
        upsample_sht=True,
    )

    # iterate over models and train each model
    root_path = os.getcwd()
    for model_name, model_handle in models.items():

        model = model_handle().to(device)

        print(model)

        metrics[model_name] = {}

        num_params = count_parameters(model)
        print(f"number of trainable params: {num_params}")
        metrics[model_name]["num_params"] = num_params

        exp_dir = os.path.join(root_path, 'checkpoints', model_name)
        if not os.path.isdir(exp_dir):
            os.makedirs(exp_dir, exist_ok=True)

        if load_checkpoint:
            model.load_state_dict(torch.load(os.path.join(exp_dir, "checkpoint.pt")))

        # run the training
        if train:
            run = wandb.init(project="local sno spherical swe", group=model_name, name=model_name + "_" + str(time.time()), config=model_handle.keywords)

            # optimizer and LR schedule for single-step training
            optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
            gscaler = torch.GradScaler("cuda", enabled=enable_amp)

            start_time = time.time()

            print(f"Training {model_name}, single step")
            train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=200, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)

            # optional multi-step fine-tuning with a smaller learning rate
            if nfuture > 0:
                print(f'Training {model_name}, {nfuture} step')
                optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
                # fixed: this previously called amp.GradScaler, but no module
                # named `amp` is imported anywhere (NameError at runtime);
                # use the same construction as the single-step branch above
                gscaler = torch.GradScaler("cuda", enabled=enable_amp)
                dataloader.dataset.nsteps = 2 * dt // dt_solver
                train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=10, loss_fn="l2", nfuture=nfuture, enable_amp=enable_amp, log_grads=log_grads)
                dataloader.dataset.nsteps = 1 * dt // dt_solver

            training_time = time.time() - start_time

            run.finish()

            torch.save(model.state_dict(), os.path.join(exp_dir, 'checkpoint.pt'))

        # re-seed so inference initial conditions are identical across models
        torch.manual_seed(333)
        torch.cuda.manual_seed(333)

        with torch.inference_mode():
            losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(exp_dir, 'figures'), nsteps=nsteps, autoreg_steps=30, nics=50)
            metrics[model_name]["loss_mean"] = np.mean(losses)
            metrics[model_name]["loss_std"] = np.std(losses)
            metrics[model_name]["fno_time_mean"] = np.mean(fno_times)
            metrics[model_name]["fno_time_std"] = np.std(fno_times)
            metrics[model_name]["nwp_time_mean"] = np.mean(nwp_times)
            metrics[model_name]["nwp_time_std"] = np.std(nwp_times)
            if train:
                metrics[model_name]["training_time"] = training_time

    df = pd.DataFrame(metrics)
    # NOTE(review): metrics are written into the LAST model's exp_dir; if a
    # run-level location is intended, root_path would be the natural choice —
    # left unchanged to preserve existing output locations, confirm intent
    if not os.path.isdir(os.path.join(exp_dir, 'output_data')):
        os.makedirs(os.path.join(exp_dir, 'output_data'), exist_ok=True)
    df.to_pickle(os.path.join(exp_dir, 'output_data', 'metrics.pkl'))
538

Boris Bonev's avatar
Boris Bonev committed
539
540
541
542

if __name__ == "__main__":
    import torch.multiprocessing as mp

    # presumably chosen so DataLoader workers do not fork CUDA state — confirm
    mp.set_start_method("forkserver", force=True)

    # evaluation-only alternative:
    # main(train=False, load_checkpoint=True, enable_amp=False, log_grads=0)
    main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0)