"examples/community/pipeline_animatediff_img2video.py" did not exist on "04d696d65053644775b104cb3af92aff8338e6fc"
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#


import torch
import torch.nn as nn
Boris Bonev's avatar
Boris Bonev committed
35
import torch_harmonics as th
36
from torch_harmonics.quadrature import _precompute_longitudes
Boris Bonev's avatar
Boris Bonev committed
37

Thorsten Kurth's avatar
Thorsten Kurth committed
38
import math
Boris Bonev's avatar
Boris Bonev committed
39
40
41
42
43
import numpy as np


class ShallowWaterSolver(nn.Module):
    """
    Shallow Water Equations (SWE) solver class for spherical geometry.

    Interface inspired by pyspharm and SHTns. Solves the shallow water equations
    on a rotating sphere using spectral methods.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    dt : float
        Time step size
    lmax : int, optional
        Maximum l mode for spherical harmonics, by default None
    mmax : int, optional
        Maximum m mode for spherical harmonics, by default None
    grid : str, optional
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    radius : float, optional
        Radius of the sphere in meters, by default 6.37122E6 (Earth radius)
    omega : float, optional
        Angular velocity of rotation in rad/s, by default 7.292E-5 (Earth)
    gravity : float, optional
        Gravitational acceleration in m/s², by default 9.80616
    havg : float, optional
        Average height in meters, by default 10.e3
    hamp : float, optional
        Height amplitude in meters, by default 120.
    """

75
    def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=6.37122E6,
                 omega=7.292E-5, gravity=9.80616, havg=10.e3, hamp=120.):
        """
        Build the transforms, grid coordinates and spectral operators of the solver.

        Parameters
        ----------
        nlat : int
            Number of latitude points
        nlon : int
            Number of longitude points
        dt : float
            Time step size
        lmax : int, optional
            Maximum l mode passed to the spherical harmonic transforms, by default None
        mmax : int, optional
            Maximum m mode passed to the spherical harmonic transforms, by default None
        grid : str, optional
            Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
        radius : float, optional
            Radius of the sphere, by default 6.37122E6
        omega : float, optional
            Angular velocity of rotation, by default 7.292E-5
        gravity : float, optional
            Gravitational acceleration, by default 9.80616
        havg : float, optional
            Average height, by default 10.e3
        hamp : float, optional
            Height amplitude, by default 120.

        Raises
        ------
        ValueError
            If ``grid`` is not one of the grid types handled below.
        """
        super().__init__()

        # time stepping param
        self.dt = dt

        # grid parameters
        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid

        # physical constants, registered as float64 buffers
        self.register_buffer('radius', torch.as_tensor(radius, dtype=torch.float64))
        self.register_buffer('omega', torch.as_tensor(omega, dtype=torch.float64))
        self.register_buffer('gravity', torch.as_tensor(gravity, dtype=torch.float64))
        self.register_buffer('havg', torch.as_tensor(havg, dtype=torch.float64))
        self.register_buffer('hamp', torch.as_tensor(hamp, dtype=torch.float64))

        # SHT
        self.sht = th.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.isht = th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.vsht = th.RealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.ivsht = th.InverseRealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)

        # The truncation is always read back from the forward transform. An
        # earlier pair of assignments (``lmax or ...``) was dead code and set
        # mmax from lmax by mistake, so it has been removed.
        self.lmax = self.sht.lmax
        self.mmax = self.sht.mmax

        # compute gridpoints
        if self.grid == "legendre-gauss":
            cost, quad_weights = th.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
        elif self.grid == "lobatto":
            cost, quad_weights = th.quadrature.lobatto_weights(self.nlat, -1, 1)
        elif self.grid == "equiangular":
            cost, quad_weights = th.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
        else:
            # without this branch an unknown grid would surface as a NameError on `cost`
            raise ValueError(
                f"Unknown grid type {self.grid!r}; expected 'equiangular', 'legendre-gauss' or 'lobatto'"
            )

        # column vector so the weights broadcast over the longitude axis
        quad_weights = quad_weights.reshape(-1, 1)

        # apply cosine transform and flip them
        lats = -torch.arcsin(cost)
        lons = _precompute_longitudes(self.nlon)

        # compute the laplace and inverse laplace operators
        l = torch.arange(0, self.lmax).reshape(self.lmax, 1).double()
        l = l.expand(self.lmax, self.mmax)
        # the laplace operator acting on the coefficients is given by - l (l + 1)
        lap = - l * (l + 1) / self.radius**2
        invlap = - self.radius**2 / l / (l + 1)
        # the l = 0 row is a division by zero above; define the inverse as zero there
        invlap[0] = 0.

        # compute coriolis force
        coriolis = 2 * self.omega * torch.sin(lats).reshape(self.nlat, 1)

        # hyperdiffusion
        hyperdiff = torch.exp(torch.asarray((-self.dt / 2 / 3600.)*(lap / lap[-1, 0])**4))

        # register all
        self.register_buffer('lats', lats)
        self.register_buffer('lons', lons)
        self.register_buffer('l', l)
        self.register_buffer('lap', lap)
        self.register_buffer('invlap', invlap)
        self.register_buffer('coriolis', coriolis)
        self.register_buffer('hyperdiff', hyperdiff)
        self.register_buffer('quad_weights', quad_weights)

    def grid2spec(self, ugrid):
        """
apaaris's avatar
apaaris committed
146
147
148
149
150
151
152
153
154
155
156
        Convert spatial data to spectral coefficients.
        
        Parameters
        -----------
        ugrid : torch.Tensor
            Spatial data tensor
            
        Returns
        -------
        torch.Tensor
            Spectral coefficients
Boris Bonev's avatar
Boris Bonev committed
157
158
159
160
161
        """
        return self.sht(ugrid)

    def spec2grid(self, uspec):
        """
apaaris's avatar
apaaris committed
162
163
164
165
166
167
168
169
170
171
172
        Convert spectral coefficients to spatial data.
        
        Parameters
        -----------
        uspec : torch.Tensor
            Spectral coefficients tensor
            
        Returns
        -------
        torch.Tensor
            Spatial data
Boris Bonev's avatar
Boris Bonev committed
173
174
175
176
        """
        return self.isht(uspec)

    def vrtdivspec(self, ugrid):
apaaris's avatar
apaaris committed
177
178
179
180
181
182
183
184
185
186
187
188
189
        """
        Compute vorticity and divergence from velocity field.
        
        Parameters
        -----------
        ugrid : torch.Tensor
            Velocity field in spatial coordinates
            
        Returns
        -------
        torch.Tensor
            Spectral coefficients of vorticity and divergence
        """
Boris Bonev's avatar
Boris Bonev committed
190
191
192
193
194
        vrtdivspec = self.lap * self.radius * self.vsht(ugrid)
        return vrtdivspec

    def getuv(self, vrtdivspec):
        """
apaaris's avatar
apaaris committed
195
196
197
198
199
200
201
202
203
204
205
        Compute wind vector from spectral coefficients of vorticity and divergence.
        
        Parameters
        -----------
        vrtdivspec : torch.Tensor
            Spectral coefficients of vorticity and divergence
            
        Returns
        -------
        torch.Tensor
            Wind vector field in spatial coordinates
Boris Bonev's avatar
Boris Bonev committed
206
207
208
209
210
        """
        return self.ivsht( self.invlap * vrtdivspec / self.radius)

    def gethuv(self, uspec):
        """
apaaris's avatar
apaaris committed
211
212
213
214
215
216
217
218
219
220
221
        Compute height and wind vector from spectral coefficients.
        
        Parameters
        -----------
        uspec : torch.Tensor
            Spectral coefficients [height, vorticity, divergence]
            
        Returns
        -------
        torch.Tensor
            Combined height and wind vector field
Boris Bonev's avatar
Boris Bonev committed
222
223
224
225
226
227
228
        """
        hgrid = self.spec2grid(uspec[:1])
        uvgrid = self.getuv(uspec[1:])
        return torch.cat((hgrid, uvgrid), dim=-3)

    def potential_vorticity(self, uspec):
        """
apaaris's avatar
apaaris committed
229
230
231
232
233
234
235
236
237
238
239
        Compute potential vorticity from spectral coefficients.
        
        Parameters
        -----------
        uspec : torch.Tensor
            Spectral coefficients [height, vorticity, divergence]
            
        Returns
        -------
        torch.Tensor
            Potential vorticity field
Boris Bonev's avatar
Boris Bonev committed
240
241
242
243
244
245
246
        """
        ugrid = self.spec2grid(uspec)
        pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0]
        return pvrt

    def dimensionless(self, uspec):
        """
apaaris's avatar
apaaris committed
247
248
249
250
251
252
253
254
255
256
257
        Remove dimensions from variables for dimensionless analysis.
        
        Parameters
        -----------
        uspec : torch.Tensor
            Spectral coefficients with dimensions
            
        Returns
        -------
        torch.Tensor
            Dimensionless spectral coefficients
Boris Bonev's avatar
Boris Bonev committed
258
259
260
261
262
263
264
265
        """
        uspec[0] = (uspec[0] - self.havg * self.gravity) / self.hamp / self.gravity
        # vorticity is measured in 1/s so we normalize using sqrt(g h) / r
        uspec[1:] = uspec[1:] * self.radius / torch.sqrt(self.gravity * self.havg)
        return uspec

    def dudtspec(self, uspec):
        """
apaaris's avatar
apaaris committed
266
267
268
269
270
271
272
273
274
275
276
        Compute time derivatives from solution represented in spectral coefficients.
        
        Parameters
        -----------
        uspec : torch.Tensor
            Spectral coefficients [height, vorticity, divergence]
            
        Returns
        -------
        torch.Tensor
            Time derivatives of spectral coefficients
Boris Bonev's avatar
Boris Bonev committed
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
        """
        dudtspec = torch.zeros_like(uspec)

        # compute the derivatives - this should be incorporated into the solver:
        ugrid = self.spec2grid(uspec)
        uvgrid = self.getuv(uspec[1:])

        # phi = ugrid[0]
        # vrtdiv = ugrid[1:]

        tmp = uvgrid * (ugrid[1] + self.coriolis)
        tmpspec = self.vrtdivspec(tmp)
        dudtspec[2] = tmpspec[0]
        dudtspec[1] = -1 * tmpspec[1]

        tmp = uvgrid * ugrid[0]
        tmp = self.vrtdivspec(tmp)
        dudtspec[0] = -1 * tmp[1]

        tmpspec = self.grid2spec(ugrid[0] + 0.5 * (uvgrid[0]**2 + uvgrid[1]**2))
        dudtspec[2] = dudtspec[2] - self.lap * tmpspec

        return dudtspec

    def galewsky_initial_condition(self):
        """
apaaris's avatar
apaaris committed
303
        Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
Boris Bonev's avatar
Boris Bonev committed
304

305
306
307
308
        Parameters
        ----------
        None
        
apaaris's avatar
apaaris committed
309
310
311
312
        Returns
        -------
        torch.Tensor
            Initial spectral coefficients for the Galewsky test case
313
314
315
316
317

        References
        ----------
        [1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
            DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
Boris Bonev's avatar
Boris Bonev committed
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
        """
        device = self.lap.device

        umax = 80.
        phi0 = torch.asarray(torch.pi / 7., device=device)
        phi1 = torch.asarray(0.5 * torch.pi - phi0, device=device)
        phi2 = 0.25 * torch.pi
        en = torch.exp(torch.asarray(-4.0 / (phi1 - phi0)**2, device=device))
        alpha = 1. / 3.
        beta = 1. / 15.

        lats, lons = torch.meshgrid(self.lats, self.lons)

        u1 = (umax/en)*torch.exp(1./((lats-phi0)*(lats-phi1)))
        ugrid = torch.where(torch.logical_and(lats < phi1, lats > phi0), u1, torch.zeros(self.nlat, self.nlon, device=device))
        vgrid = torch.zeros((self.nlat, self.nlon), device=device)
        hbump = self.hamp * torch.cos(lats) * torch.exp(-((lons-torch.pi)/alpha)**2) * torch.exp(-(phi2-lats)**2/beta)

        # intial velocity field
        ugrid = torch.stack((ugrid, vgrid))
        # intial vorticity/divergence field
        vrtdivspec = self.vrtdivspec(ugrid)
        vrtdivgrid = self.spec2grid(vrtdivspec)

        # solve balance eqn to get initial zonal geopotential with a localized bump (not balanced).
        tmp = ugrid * (vrtdivgrid + self.coriolis)
        tmpspec = self.vrtdivspec(tmp)
        tmpspec[1] = self.grid2spec(0.5 * torch.sum(ugrid**2, dim=0))
        phispec = self.invlap*tmpspec[0] - tmpspec[1] + self.grid2spec(self.gravity*(self.havg + hbump))

        # assemble solution
        uspec = torch.zeros(3, self.lmax, self.mmax, dtype=vrtdivspec.dtype, device=device)
        uspec[0] = phispec
        uspec[1:] = vrtdivspec

        return torch.tril(uspec)

    def random_initial_condition(self, mach=0.1) -> torch.Tensor:
        """
apaaris's avatar
apaaris committed
357
358
359
360
361
362
363
364
365
366
367
        Generate random initial condition on the sphere.
        
        Parameters
        -----------
        mach : float, optional
            Mach number for scaling the random perturbations, by default 0.1
            
        Returns
        -------
        torch.Tensor
            Random initial spectral coefficients
Boris Bonev's avatar
Boris Bonev committed
368
369
370
371
372
        """
        device = self.lap.device
        ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64

        # mach number relative to wave speed
373
        llimit = mlimit = 120
Boris Bonev's avatar
Boris Bonev committed
374
375
376
377
378
379
380
381
382

        # hgrid = self.havg + hamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
        # ugrid = uamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
        # vgrid = vamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
        # ugrid = torch.stack((ugrid, vgrid))

        # initial geopotential
        uspec = torch.zeros(3, self.lmax, self.mmax, dtype=ctype, device=self.lap.device)
        uspec[:, :llimit, :mlimit] = torch.sqrt(torch.tensor(4 * torch.pi / llimit / (llimit+1), device=device, dtype=ctype)) * torch.randn_like(uspec[:, :llimit, :mlimit])
383

Boris Bonev's avatar
Boris Bonev committed
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
        uspec[0] = self.gravity * self.hamp * uspec[0]
        uspec[0, 0, 0] += torch.sqrt(torch.tensor(4 * torch.pi, device=device, dtype=ctype)) * self.havg * self.gravity
        uspec[1:] = mach * uspec[1:] * torch.sqrt(self.gravity * self.havg) / self.radius
        # uspec[1:] = self.vrtdivspec(self.spec2grid(uspec[1:]) * torch.cos(self.lats.reshape(-1, 1)))

        # # intial velocity field
        # ugrid = uamp * self.spec2grid(uspec[1])
        # vgrid = vamp * self.spec2grid(uspec[2])
        # ugrid = torch.stack((ugrid, vgrid))



        # # intial vorticity/divergence field
        # vrtdivspec = self.vrtdivspec(ugrid)
        # vrtdivgrid = self.spec2grid(vrtdivspec)

        # # solve balance eqn to get initial zonal geopotential with a localized bump (not balanced).
        # tmp = ugrid * (vrtdivgrid + self.coriolis)
        # tmpspec = self.vrtdivspec(tmp)
        # tmpspec[1] = self.grid2spec(0.5 * torch.sum(ugrid**2, dim=0))
        # phispec = self.invlap*tmpspec[0] - tmpspec[1] + self.grid2spec(self.gravity * hgrid)

        # # assemble solution
        # uspec = torch.zeros(3, self.lmax, self.mmax, dtype=phispec.dtype, device=device)
        # uspec[0] = phispec
        # uspec[1:] = vrtdivspec
410

Boris Bonev's avatar
Boris Bonev committed
411
412
413
414
415
        return torch.tril(uspec)

    def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor:
        """
        Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps.
416
417
418
419
420
421
422
423
424
425
426
427

        Parameters
        ----------
        uspec : torch.Tensor
            Spectral coefficients [height, vorticity, divergence]
        nsteps : int
            Number of time steps to integrate

        Returns
        -------
        torch.Tensor
            Integrated spectral coefficients
Boris Bonev's avatar
Boris Bonev committed
428
429
430
431
432
433
434
435
436
437
438
        """

        dudtspec = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device)

        # pointers to indicate the most current result
        inew = 0
        inow = 1
        iold = 2

        for iter in range(nsteps):
            dudtspec[inew] = self.dudtspec(uspec)
439

Boris Bonev's avatar
Boris Bonev committed
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
            # update vort,div,phiv with third-order adams-bashforth.
            # forward euler, then 2nd-order adams-bashforth time steps to start.
            if iter == 0:
                dudtspec[inow] = dudtspec[inew]
                dudtspec[iold] = dudtspec[inew]
            elif iter == 1:
                dudtspec[iold] = dudtspec[inew]

            uspec = uspec + self.dt*( (23./12.) * dudtspec[inew] - (16./12.) * dudtspec[inow] + (5./12.) * dudtspec[iold] )

            # implicit hyperdiffusion for vort and div.
            uspec[1:] = self.hyperdiff * uspec[1:]

            # cycle through the indices
            inew = (inew - 1) % 3
            inow = (inow - 1) % 3
            iold = (iold - 1) % 3
457

Boris Bonev's avatar
Boris Bonev committed
458
459
460
        return uspec

    def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0):
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
        """
        Integrate the solution on the grid.

        Parameters
        ----------
        ugrid : torch.Tensor
            Grid data
        dimensionless : bool, optional
            Whether to use dimensionless units, by default False
        polar_opt : int, optional
            Number of polar points to exclude, by default 0

        Returns
        -------
        torch.Tensor
            Integrated grid data
        """
478
        dlon = 2 * torch.pi / self.nlon
Boris Bonev's avatar
Boris Bonev committed
479
480
481
482
483
484
485
486
        radius = 1 if dimensionless else self.radius
        if polar_opt > 0:
            out = torch.sum(ugrid[..., polar_opt:-polar_opt, :] * self.quad_weights[polar_opt:-polar_opt] * dlon * radius**2, dim=(-2, -1))
        else:
            out = torch.sum(ugrid * self.quad_weights * dlon * radius**2, dim=(-2, -1))
        return out


Boris Bonev's avatar
Boris Bonev committed
487
    def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
        """
        Plot data given on the grid into a new subplot of ``fig``.

        Requires matplotlib; the '3d' and 'robinson' projections additionally
        require cartopy.

        Parameters
        ----------
        data : torch.Tensor
            Grid data to plot
        fig : matplotlib figure
            Figure that receives the new subplot
        cmap : str, optional
            Colormap name, by default 'twilight_shifted'
        vmax, vmin : float, optional
            Color limits forwarded to ``pcolormesh``, by default None
        projection : str, optional
            One of 'mollweide', '3d' (orthographic) or 'robinson', by default '3d'
        title : str, optional
            Plot title, by default None
        antialiased : bool, optional
            Forwarded to ``pcolormesh`` for the cartopy projections, by default False

        Returns
        -------
        The object returned by ``pcolormesh``

        Raises
        ------
        NotImplementedError
            If ``projection`` is not one of the supported values
        """
        import matplotlib.pyplot as plt

        # shift longitudes by pi before building the plotting mesh
        lon_axis = self.lons.squeeze() - torch.pi
        lat_axis = self.lats.squeeze()

        # numpy needs host memory
        if data.is_cuda:
            data = data.cpu()
            lon_axis = lon_axis.cpu()
            lat_axis = lat_axis.cpu()

        lon_mesh, lat_mesh = np.meshgrid(lon_axis, lat_axis)

        if projection == 'mollweide':
            ax = fig.add_subplot(projection=projection)
            im = ax.pcolormesh(lon_mesh, lat_mesh, data, cmap=cmap, vmax=vmax, vmin=vmin)
            ax.grid(True)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            plt.colorbar(im, orientation='horizontal')
            plt.title(title)

        elif projection in ('3d', 'robinson'):
            # cartopy is imported lazily so that mollweide plots work without it
            import cartopy.crs as ccrs

            if projection == '3d':
                proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=25.0)
            else:
                proj = ccrs.Robinson(central_longitude=0.0)

            ax = fig.add_subplot(projection=proj)

            # the PlateCarree transform expects degrees
            lon_mesh = lon_mesh*180/math.pi
            lat_mesh = lat_mesh*180/math.pi

            # contour data over the map.
            im = ax.pcolormesh(lon_mesh, lat_mesh, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
            plt.title(title, y=1.05)

        else:
            raise NotImplementedError

        return im

    def plot_specdata(self, data, fig, **kwargs):
        return self.plot_griddata(self.isht(data), fig, **kwargs)