Commit e731f2a5 authored by Boris Bonev's avatar Boris Bonev
Browse files

Fixing imports in examples

parent 3711e016
...@@ -30,21 +30,11 @@ ...@@ -30,21 +30,11 @@
# #
import sys
sys.path.append("..")
sys.path.append(".")
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch_harmonics as harmonics import torch_harmonics as harmonics
import numpy as np import numpy as np
import matplotlib.pyplot as plt
try:
import cartopy.crs as ccrs
except ImportError:
ccrs = None
class SphereSolver(nn.Module): class SphereSolver(nn.Module):
...@@ -140,6 +130,7 @@ class SphereSolver(nn.Module): ...@@ -140,6 +130,7 @@ class SphereSolver(nn.Module):
""" """
plotting routine for data on the grid. Requires cartopy for 3d plots. plotting routine for data on the grid. Requires cartopy for 3d plots.
""" """
import matplotlib.pyplot as plt
lons = self.lons.squeeze() - torch.pi lons = self.lons.squeeze() - torch.pi
lats = self.lats.squeeze() lats = self.lats.squeeze()
...@@ -165,8 +156,7 @@ class SphereSolver(nn.Module): ...@@ -165,8 +156,7 @@ class SphereSolver(nn.Module):
elif projection == '3d': elif projection == '3d':
if ccrs is None: import cartopy.crs as ccrs
raise ImportError("Couldn't import Cartopy")
proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=25.0) proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=25.0)
......
...@@ -30,22 +30,13 @@ ...@@ -30,22 +30,13 @@
# #
import sys
sys.path.append("..")
sys.path.append(".")
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch_harmonics as harmonics import torch_harmonics as harmonics
from torch_harmonics.quadrature import * from torch_harmonics.quadrature import *
import numpy as np import numpy as np
import matplotlib.pyplot as plt
try:
import cartopy.crs as ccrs
except ImportError:
ccrs = None
class ShallowWaterSolver(nn.Module): class ShallowWaterSolver(nn.Module):
""" """
...@@ -337,6 +328,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -337,6 +328,7 @@ class ShallowWaterSolver(nn.Module):
""" """
plotting routine for data on the grid. Requires cartopy for 3d plots. plotting routine for data on the grid. Requires cartopy for 3d plots.
""" """
import matplotlib.pyplot as plt
lons = self.lons.squeeze() - torch.pi lons = self.lons.squeeze() - torch.pi
lats = self.lats.squeeze() lats = self.lats.squeeze()
...@@ -362,8 +354,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -362,8 +354,7 @@ class ShallowWaterSolver(nn.Module):
elif projection == '3d': elif projection == '3d':
if ccrs is None: import cartopy.crs as ccrs
raise ImportError("Couldn't import Cartopy")
proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=25.0) proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=25.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment