train_sfno.py 19.4 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
Boris Bonev's avatar
Boris Bonev committed
5
#
Boris Bonev's avatar
Boris Bonev committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import time

from tqdm import tqdm
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

47
from torch_harmonics.examples import PdeDataset
from torch_harmonics import RealSHT

# wandb logging
import wandb

# NOTE: logs in to Weights & Biases at import time so training runs can
# stream metrics; this may prompt for credentials in a fresh environment
wandb.login()
56
def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
    """(squared) L2 loss on the sphere, computed by quadrature on the grid.

    Integrates ``(prd - tar)**2`` over the sphere via ``solver.integrate_grid``,
    sums over the channel dimension and averages over the batch.

    Parameters
    ----------
    solver : object providing ``integrate_grid``
    prd, tar : prediction and target fields on the grid
    relative : normalize by the integrated target energy
    squared : if False, take the square root before batch-averaging
    """
    err2 = solver.integrate_grid((prd - tar) ** 2, dimensionless=True).sum(dim=-1)

    if relative:
        tar2 = solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)
        err2 = err2 / tar2

    result = err2 if squared else torch.sqrt(err2)

    return result.mean()
def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=True):
    """(Squared) L2 loss evaluated in spectral space via the SHT.

    The grid L2 norm is computed as the sum of squared SHT coefficient
    magnitudes; for real-valued fields the m > 0 modes appear twice, hence
    the factor of two on all but the first order.
    """

    def total_power(field):
        # squared magnitudes of the SHT coefficients of `field`
        c = torch.view_as_real(solver.sht(field))
        power = c[..., 0].pow(2) + c[..., 1].pow(2)
        # m = 0 counted once, m > 0 counted twice (real field)
        per_degree = power[..., :, 0] + 2 * power[..., :, 1:].sum(dim=-1)
        return per_degree.sum(dim=(-1, -2))

    loss = total_power(prd - tar)
    if relative:
        loss = loss / total_power(tar)

    if not squared:
        loss = torch.sqrt(loss)

    return loss.mean()
def spectral_loss_sphere(solver, prd, tar, relative=False, squared=True):
Boris Bonev's avatar
Boris Bonev committed
90
91
92
    # gradient weighting factors
    lmax = solver.sht.lmax
    ls = torch.arange(lmax).float()
93
    spectral_weights = (ls * (ls + 1)).reshape(1, 1, -1, 1).to(prd.device)
Boris Bonev's avatar
Boris Bonev committed
94
95
96

    # compute coefficients
    coeffs = torch.view_as_real(solver.sht(prd - tar))
97
    coeffs = coeffs[..., 0] ** 2 + coeffs[..., 1] ** 2
Boris Bonev's avatar
Boris Bonev committed
98
99
    coeffs = spectral_weights * coeffs
    norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
100
    loss = torch.sum(norm2, dim=(-1, -2))
Boris Bonev's avatar
Boris Bonev committed
101
102
103

    if relative:
        tar_coeffs = torch.view_as_real(solver.sht(tar))
104
        tar_coeffs = tar_coeffs[..., 0] ** 2 + tar_coeffs[..., 1] ** 2
Boris Bonev's avatar
Boris Bonev committed
105
106
        tar_coeffs = spectral_weights * tar_coeffs
        tar_norm2 = tar_coeffs[..., :, 0] + 2 * torch.sum(tar_coeffs[..., :, 1:], dim=-1)
107
        tar_norm2 = torch.sum(tar_norm2, dim=(-1, -2))
Boris Bonev's avatar
Boris Bonev committed
108
109
110
111
112
113
114
115
        loss = loss / tar_norm2

    if not squared:
        loss = torch.sqrt(loss)
    loss = loss.mean()

    return loss

116

117
def h1loss_sphere(solver, prd, tar, relative=False, squared=True):
    """H1-type loss on the sphere: spectral L2 norm plus the l(l+1)-weighted
    (gradient) spectral norm of the error.

    Parameters
    ----------
    solver : object exposing ``sht`` (callable SHT with ``lmax`` attribute)
    prd, tar : prediction and target fields on the grid
    relative : not supported, raises NotImplementedError
    squared : if False, sum the square roots of the two contributions

    Raises
    ------
    NotImplementedError : when ``relative=True``.
    """
    # fail fast: the relative variant is not implemented, so report that
    # before spending an SHT evaluation on the error field
    if relative:
        raise NotImplementedError("Relative H1 loss not implemented")

    # gradient weighting factors l(l+1)
    lmax = solver.sht.lmax
    ls = torch.arange(lmax).float()
    spectral_weights = (ls * (ls + 1)).reshape(1, 1, -1, 1).to(prd.device)

    # squared magnitudes of the SHT coefficients of the error
    coeffs = torch.view_as_real(solver.sht(prd - tar))
    coeffs = coeffs[..., 0] ** 2 + coeffs[..., 1] ** 2

    # gradient part and plain L2 part; m > 0 modes counted twice (real field)
    h1_coeffs = spectral_weights * coeffs
    h1_norm2 = h1_coeffs[..., :, 0] + 2 * torch.sum(h1_coeffs[..., :, 1:], dim=-1)
    l2_norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)

    h1_loss = torch.sum(h1_norm2, dim=(-1, -2))
    l2_loss = torch.sum(l2_norm2, dim=(-1, -2))

    # strictly speaking this is not exactly h1 loss
    if not squared:
        loss = torch.sqrt(h1_loss) + torch.sqrt(l2_loss)
    else:
        loss = h1_loss + l2_loss

    loss = loss.mean()

    return loss
def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
    """L2 loss weighted per channel by the target's fluctuation w.r.t. the input.

    Channels whose target changed more relative to the input state receive a
    proportionally larger share of the loss. With ``relative=True`` the loss
    is normalized by the weighted integrated target energy.
    """
    integrate = partial(solver.integrate_grid, dimensionless=True, polar_opt=polar_opt)

    # per-channel fluctuation of the target around the input, normalized so
    # the weights sum to one over the channel dimension
    fluct = integrate((tar - inp) ** 2)
    weight = fluct / torch.sum(fluct, dim=-1, keepdim=True)

    loss = weight * integrate((prd - tar) ** 2)
    if relative:
        loss = loss / (weight * integrate(tar**2))

    return torch.mean(loss)
# rolls out the FNO and compares to the classical solver
def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10, nskip=1, plot_channel=0, nics=50):
    """Roll out the trained model autoregressively and compare it to the
    classical solver on random initial conditions.

    For each of ``nics`` initial conditions the model and the reference
    solver are advanced ``autoreg_steps`` steps; relative L2 losses against
    the reference, wall-clock timings for both, and averaged power spectra
    are collected. For the last initial condition, snapshots are plotted
    every ``nskip`` steps (``nskip=0`` disables plotting) into ``path_root``.

    Returns
    -------
    (losses, fno_times, nwp_times) : three np.ndarray of length ``nics``.
    """
    model.eval()

    # make output directory
    if not os.path.isdir(path_root):
        os.makedirs(path_root, exist_ok=True)

    losses = np.zeros(nics)
    fno_times = np.zeros(nics)
    nwp_times = np.zeros(nics)

    # accumulation buffers for the power spectrum
    prd_mean_coeffs = []
    ref_mean_coeffs = []

    for iic in range(nics):
        ic = dataset.solver.random_initial_condition(mach=0.2)
        inp_mean = dataset.inp_mean
        inp_var = dataset.inp_var

        # normalized initial state for the model; keep the spectral state
        # for the classical solver
        prd = (dataset.solver.spec2grid(ic) - inp_mean) / torch.sqrt(inp_var)
        prd = prd.unsqueeze(0)
        uspec = ic.clone()

        # add IC to power spectrum series
        prd_coeffs = [dataset.sht(prd[0, plot_channel]).detach().cpu().clone()]
        ref_coeffs = [prd_coeffs[0].clone()]

        # ML model
        start_time = time.time()
        for i in range(1, autoreg_steps + 1):
            # evaluate the ML model
            prd = model(prd)

            prd_coeffs.append(dataset.sht(prd[0, plot_channel]).detach().cpu().clone())

            if iic == nics - 1 and nskip > 0 and i % nskip == 0:
                # do plotting
                fig = plt.figure(figsize=(7.5, 6))
                dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
                plt.savefig(os.path.join(path_root,'pred_'+str(i//nskip)+'.png'))
                plt.close()

        fno_times[iic] = time.time() - start_time

        # classical model
        start_time = time.time()
        for i in range(1, autoreg_steps + 1):
            # advance classical model
            uspec = dataset.solver.timestep(uspec, nsteps)
            ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
            ref_coeffs.append(dataset.sht(ref[plot_channel]).detach().cpu().clone())

            # FIX: check nskip > 0 before evaluating i % nskip to avoid a
            # ZeroDivisionError when plotting is disabled (nskip == 0);
            # mirrors the guard in the ML loop above
            if iic == nics - 1 and nskip > 0 and i % nskip == 0:
                fig = plt.figure(figsize=(7.5, 6))
                dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
                plt.savefig(os.path.join(path_root,'truth_'+str(i//nskip)+'.png'))
                plt.close()

        nwp_times[iic] = time.time() - start_time

        # compute power spectrum and add it to the buffers
        prd_mean_coeffs.append(torch.stack(prd_coeffs, 0))
        ref_mean_coeffs.append(torch.stack(ref_coeffs, 0))

        # compare in unnormalized units
        # ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
        ref = dataset.solver.spec2grid(uspec)
        prd = prd * torch.sqrt(inp_var) + inp_mean
        losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()

    # compute the averaged powerspectra of prediction and reference
    with torch.no_grad():
        prd_mean_coeffs = torch.stack(prd_mean_coeffs, dim=0).abs().pow(2).mean(dim=0)
        ref_mean_coeffs = torch.stack(ref_mean_coeffs, dim=0).abs().pow(2).mean(dim=0)

        # m > 0 coefficients appear twice in the power spectrum of a real field
        prd_mean_coeffs[..., 1:] *= 2.0
        ref_mean_coeffs[..., 1:] *= 2.0
        prd_mean_ps = prd_mean_coeffs.sum(dim=-1).contiguous()
        ref_mean_ps = ref_mean_coeffs.sum(dim=-1).contiguous()

        # split into one spectrum per autoregressive step
        prd_mean_ps = [x.squeeze() for x in list(torch.split(prd_mean_ps, 1, dim=0))]
        ref_mean_ps = [x.squeeze() for x in list(torch.split(ref_mean_ps, 1, dim=0))]

    # plot the averaged powerspectrum for each step
    for step, (pps, rps) in enumerate(zip(prd_mean_ps, ref_mean_ps)):
        fig = plt.figure(figsize=(7.5, 6))
        plt.semilogy(pps, label="prediction")
        plt.semilogy(rps, label="reference")
        plt.xlabel("$l$")
        plt.ylabel("powerspectrum")
        plt.legend()
        plt.savefig(os.path.join(path_root,f'powerspectrum_{step}.png'))
        fig.clf()
        plt.close()

    return losses, fno_times, nwp_times
# convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1):
    """
    Helper routine intended for debugging purposes.

    Dumps all named parameters of `model` and their current gradients to
    ``weights_and_grads/weights_and_grads_step{iters:03d}.tar`` next to this
    script, so they can be inspected offline with ``torch.load``.
    """
    root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
    # FIX: torch.save does not create missing directories; ensure the output
    # directory exists before writing
    os.makedirs(root_path, exist_ok=True)

    weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
    print(weights_and_grads_fname)

    weights_dict = {k: v for k, v in model.named_parameters()}
    # .grad is None for parameters that have not received a gradient yet
    grad_dict = {k: v.grad for k, v in model.named_parameters()}

    store_dict = {"iteration": iters, "grads": grad_dict, "weights": weights_dict}
    torch.save(store_dict, weights_and_grads_fname)
# training function
def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=20, nfuture=0, num_examples=256, num_valid=8, loss_fn="l2", enable_amp=False, log_grads=0):
    """Train `model` for `nepochs` epochs and validate after each epoch.

    Each epoch draws `num_examples` random training examples and `num_valid`
    validation examples from the same dataloader (the dataset is re-configured
    in place between the phases). Training minimizes the loss selected by
    `loss_fn` ("l2", "spectral l2", "h1", "spectral" or "fluct"); validation
    always uses the relative L2 loss. With `nfuture` > 0 the model is applied
    `nfuture` extra autoregressive steps before the loss is computed.
    `log_grads` > 0 dumps weights/grads every `log_grads` iterations.

    Returns the validation loss of the final epoch.
    """
    train_start = time.time()

    # count iterations (across epochs, for gradient logging)
    iters = 0

    for epoch in range(nepochs):

        # time each epoch
        epoch_start = time.time()

        # configure the dataset for the training phase
        dataloader.dataset.set_initial_condition("random")
        dataloader.dataset.set_num_examples(num_examples)

        # get the solver for its convenience functions
        solver = dataloader.dataset.solver

        # do the training
        acc_loss = 0
        model.train()

        for inp, tar in dataloader:

            # forward + loss under autocast; on CPU-only runs this context is
            # effectively inert when enable_amp is False
            with torch.autocast(device_type="cuda", enabled=enable_amp):

                prd = model(inp)
                # optional additional autoregressive steps
                for _ in range(nfuture):
                    prd = model(prd)

                if loss_fn == "l2":
                    loss = l2loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "spectral l2":
                    loss = spectral_l2loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "h1":
                    loss = h1loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "spectral":
                    loss = spectral_loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == "fluct":
                    # note: the fluctuation loss is used in its relative form
                    loss = fluct_l2loss_sphere(solver, prd, tar, inp, relative=True)
                else:
                    raise NotImplementedError(f"Unknown loss function {loss_fn}")

            # accumulate the batch loss weighted by batch size
            acc_loss += loss.item() * inp.size(0)

            optimizer.zero_grad(set_to_none=True)
            # scale the loss for AMP-safe backprop
            gscaler.scale(loss).backward()

            # optional gradient logging every `log_grads` iterations
            if log_grads and iters % log_grads == 0:
                log_weights_and_grads(model, iters=iters)

            gscaler.step(optimizer)
            gscaler.update()

            iters += 1

        acc_loss = acc_loss / len(dataloader.dataset)

        # reconfigure the dataset for the validation phase
        dataloader.dataset.set_initial_condition("random")
        dataloader.dataset.set_num_examples(num_valid)

        # perform validation
        valid_loss = 0
        model.eval()

        with torch.no_grad():
            for inp, tar in dataloader:
                prd = model(inp)
                for _ in range(nfuture):
                    prd = model(prd)
                # validation always uses relative L2
                loss = l2loss_sphere(solver, prd, tar, relative=True)

                valid_loss += loss.item() * inp.size(0)

        valid_loss = valid_loss / len(dataloader.dataset)

        # scheduler (e.g. ReduceLROnPlateau) steps on the validation loss
        if scheduler is not None:
            scheduler.step(valid_loss)

        epoch_time = time.time() - epoch_start

        print(f"--------------------------------------------------------------------------------")
        print(f"Epoch {epoch} summary:")
        print(f"time taken: {epoch_time}")
        print(f"accumulated training loss: {acc_loss}")
        print(f"relative validation loss: {valid_loss}")

        # only log when a wandb run is active
        if wandb.run is not None:
            current_lr = optimizer.param_groups[0]["lr"]
            wandb.log({"loss": acc_loss, "validation loss": valid_loss, "learning rate": current_lr})

    train_time = time.time() - train_start

    print(f"--------------------------------------------------------------------------------")
    print(f"done. Training took {train_time}.")
    return valid_loss
def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
    """Set up the SWE dataset, then train and evaluate each configured model.

    For every model configuration: optionally load a checkpoint, optionally
    train (logged to wandb, checkpoint saved afterwards), then run
    autoregressive inference against the classical solver and collect metrics
    into a pandas DataFrame saved as a pickle.

    Parameters
    ----------
    train : run the training phase for each model
    load_checkpoint : load an existing checkpoint before training/inference
    enable_amp : enable automatic mixed precision
    log_grads : if > 0, dump weights/grads every `log_grads` iterations
    """
    # set seed
    torch.manual_seed(333)
    torch.cuda.manual_seed(333)

    # set parameters; nfuture > 0 would enable a multi-step finetuning phase
    nfuture = 0

    # set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device.index)

    # 1 hour prediction steps
    dt = 1 * 3600
    dt_solver = 150
    nsteps = dt // dt_solver

    grid = "legendre-gauss"
    nlat, nlon = (181, 360)
    dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(nlat, nlon), device=device, grid=grid, normalize=True)
    # SHT used for the power-spectrum diagnostics during inference
    dataset.sht = RealSHT(nlat=nlat, nlon=nlon, grid=grid).to(device=device)
    # There is still an issue with parallel dataloading. Do NOT use it at the moment
    # dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)

    nlat = dataset.nlat
    nlon = dataset.nlon

    def count_parameters(model):
        # number of trainable parameters
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    # prepare dicts containing models and corresponding metrics
    models = {}
    metrics = {}

    from torch_harmonics.examples.models import SphericalFourierNeuralOperatorNet as SFNO
    from torch_harmonics.examples.models import LocalSphericalNeuralOperatorNet as LSNO

    models["sfno_sc2_layers4_e32"] = partial(
        SFNO,
        img_size=(nlat, nlon),
        grid=grid,
        hard_thresholding_fraction=0.8,
        num_layers=4,
        scale_factor=2,
        embed_dim=32,
        operator_type="driscoll-healy",
        activation_function="gelu",
        big_skip=True,
        pos_embed=False,
        use_mlp=True,
        normalization_layer="none",
    )

    models["lsno_sc2_layers4_e32_morlet"] = partial(
        LSNO,
        img_size=(nlat, nlon),
        grid=grid,
        num_layers=4,
        scale_factor=2,
        embed_dim=32,
        operator_type="driscoll-healy",
        activation_function="gelu",
        big_skip=False,
        pos_embed=False,
        use_mlp=True,
        normalization_layer="none",
        kernel_shape=[4, 4],
        encoder_kernel_shape=[4, 4],
        filter_basis_type="morlet"
    )

    models["lsno_sc2_layers4_e32_zernike"] = partial(
        LSNO,
        img_size=(nlat, nlon),
        grid=grid,
        num_layers=4,
        scale_factor=2,
        embed_dim=32,
        operator_type="driscoll-healy",
        activation_function="gelu",
        big_skip=False,
        pos_embed=False,
        use_mlp=True,
        normalization_layer="none",
        kernel_shape=[4],
        encoder_kernel_shape=[4],
        filter_basis_type="zernike"
    )

    # iterate over models and train each model
    root_path = os.path.dirname(__file__)
    for model_name, model_handle in models.items():

        model = model_handle().to(device)

        print(model)

        metrics[model_name] = {}

        num_params = count_parameters(model)
        print(f"number of trainable params: {num_params}")
        metrics[model_name]["num_params"] = num_params

        exp_dir = os.path.join(root_path, 'checkpoints', model_name)
        if not os.path.isdir(exp_dir):
            os.makedirs(exp_dir, exist_ok=True)

        if load_checkpoint:
            model.load_state_dict(torch.load(os.path.join(exp_dir, "checkpoint.pt")))

        # run the training
        if train:
            run = wandb.init(project="local sno spherical swe", group=model_name, name=model_name + "_" + str(time.time()), config=model_handle.keywords)

            # optimizer:
            optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
            gscaler = torch.GradScaler("cuda", enabled=enable_amp)

            start_time = time.time()

            print(f"Training {model_name}, single step")
            train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=100, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)

            # optional multi-step finetuning phase (disabled while nfuture == 0)
            if nfuture > 0:
                print(f'Training {model_name}, {nfuture} step')
                optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
                # FIX: `amp` was never imported (NameError when this branch
                # runs); construct the scaler the same way as the single-step
                # phase above
                gscaler = torch.GradScaler("cuda", enabled=enable_amp)
                dataloader.dataset.nsteps = 2 * dt // dt_solver
                train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=10, loss_fn="l2", nfuture=nfuture, enable_amp=enable_amp, log_grads=log_grads)
                dataloader.dataset.nsteps = 1 * dt // dt_solver

            training_time = time.time() - start_time

            run.finish()

            torch.save(model.state_dict(), os.path.join(exp_dir, 'checkpoint.pt'))

        # re-seed so inference is reproducible regardless of training
        torch.manual_seed(333)
        torch.cuda.manual_seed(333)

        with torch.inference_mode():
            losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(exp_dir,'figures'), nsteps=nsteps, autoreg_steps=30, nics=50)
            metrics[model_name]["loss_mean"] = np.mean(losses)
            metrics[model_name]["loss_std"] = np.std(losses)
            metrics[model_name]["fno_time_mean"] = np.mean(fno_times)
            metrics[model_name]["fno_time_std"] = np.std(fno_times)
            metrics[model_name]["nwp_time_mean"] = np.mean(nwp_times)
            metrics[model_name]["nwp_time_std"] = np.std(nwp_times)
            if train:
                metrics[model_name]["training_time"] = training_time

    df = pd.DataFrame(metrics)
    # NOTE(review): metrics are written under the LAST model's checkpoint
    # directory (`exp_dir` retains its final loop value) — confirm intended
    if not os.path.isdir(os.path.join(exp_dir, 'output_data',)):
        os.makedirs(os.path.join(exp_dir, 'output_data'), exist_ok=True)
    df.to_pickle(os.path.join(exp_dir, 'output_data', 'metrics.pkl'))
if __name__ == "__main__":
    import torch.multiprocessing as mp

    # NOTE(review): forkserver presumably chosen for safety with CUDA and
    # dataloader worker processes — confirm; force=True overrides any
    # previously set start method
    mp.set_start_method("forkserver", force=True)

    # evaluation-only variant: reuse saved checkpoints without retraining
    # main(train=False, load_checkpoint=True, enable_amp=False, log_grads=0)
    main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0)