Commit 4d04d055 authored by mashun1
graphcast
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for creating icosahedral meshes."""
import itertools
from typing import List, NamedTuple, Sequence, Tuple
import numpy as np
from scipy.spatial import transform
class TriangularMesh(NamedTuple):
"""Data structure for triangular meshes.
Attributes:
vertices: spatial positions of the vertices of the mesh of shape
[num_vertices, num_dims].
faces: triangular faces of the mesh of shape [num_faces, 3]. Contains
integer indices into `vertices`.
"""
vertices: np.ndarray
faces: np.ndarray
def merge_meshes(
mesh_list: Sequence[TriangularMesh]) -> TriangularMesh:
"""Merges all meshes into one. Assumes the last mesh is the finest.
Args:
mesh_list: Sequence of meshes, from coarse to fine refinement levels. The
vertices and faces may contain those from preceding, coarser levels.
Returns:
`TriangularMesh` for which the vertices correspond to the highest
resolution mesh in the hierarchy, and the faces are the union of the
faces at all levels of the hierarchy.
"""
for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list):
num_nodes_mesh_i = mesh_i.vertices.shape[0]
assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i])
return TriangularMesh(
vertices=mesh_list[-1].vertices,
faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0))
def get_hierarchy_of_triangular_meshes_for_sphere(
splits: int) -> List[TriangularMesh]:
"""Returns a sequence of meshes, each with triangularization sphere.
Starting with a regular icosahedron (12 vertices, 20 faces, 30 edges) with
circumscribed unit sphere. Then, each triangular face is iteratively
subdivided into 4 triangular faces `splits` times. The new vertices are then
projected back onto the unit sphere. All resulting meshes are returned in a
list, from lowest to highest resolution.
The vertices in each face are specified in counter-clockwise order as
observed from outside the icosahedron.
Args:
splits: How many times to split each triangle.
Returns:
Sequence of `TriangularMesh`s of length `splits + 1` each with:
vertices: [num_vertices, 3] vertex positions in 3D, all with unit norm.
faces: [num_faces, 3] with triangular faces joining sets of 3 vertices.
Each row contains three indices into the vertices array, indicating
the vertices adjacent to the face. Always with positive orientation
(counter-clockwise when looking from the outside).
"""
current_mesh = get_icosahedron()
output_meshes = [current_mesh]
for _ in range(splits):
current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh)
output_meshes.append(current_mesh)
return output_meshes
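# A minimal usage sketch (the helper below is hypothetical, for illustration
# only): builds a small hierarchy, merges it, and checks the expected sizes.
# With `splits` subdivisions the finest mesh has 10 * 4**splits + 2 vertices
# and 20 * 4**splits faces.
def _example_mesh_hierarchy(splits: int = 2) -> TriangularMesh:
  meshes = get_hierarchy_of_triangular_meshes_for_sphere(splits=splits)
  finest = meshes[-1]
  assert finest.vertices.shape == (10 * 4**splits + 2, 3)
  assert finest.faces.shape == (20 * 4**splits, 3)
  # The merged mesh keeps the finest vertices and the union of all faces.
  merged = merge_meshes(meshes)
  assert merged.vertices.shape == finest.vertices.shape
  return merged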
def get_icosahedron() -> TriangularMesh:
"""Returns a regular icosahedral mesh with circumscribed unit sphere.
See https://en.wikipedia.org/wiki/Regular_icosahedron#Cartesian_coordinates
for details on the construction of the regular icosahedron.
The vertices in each face are specified in counter-clockwise order as observed
from the outside of the icosahedron.
Returns:
TriangularMesh with:
vertices: [num_vertices=12, 3] vertex positions in 3D, all with unit norm.
faces: [num_faces=20, 3] with triangular faces joining sets of 3 vertices.
Each row contains three indices into the vertices array, indicating
the vertices adjacent to the face. Always with positive orientation (
counter-clockwise when looking from the outside).
"""
phi = (1 + np.sqrt(5)) / 2
vertices = []
for c1 in [1., -1.]:
for c2 in [phi, -phi]:
vertices.append((c1, c2, 0.))
vertices.append((0., c1, c2))
vertices.append((c2, 0., c1))
vertices = np.array(vertices, dtype=np.float32)
vertices /= np.linalg.norm([1., phi])
# I did this manually, checking the orientation one by one.
faces = [(0, 1, 2),
(0, 6, 1),
(8, 0, 2),
(8, 4, 0),
(3, 8, 2),
(3, 2, 7),
(7, 2, 1),
(0, 4, 6),
(4, 11, 6),
(6, 11, 5),
(1, 5, 7),
(4, 10, 11),
(4, 8, 10),
(10, 8, 3),
(10, 3, 9),
(11, 10, 9),
(11, 9, 5),
(5, 9, 7),
(9, 3, 7),
(1, 6, 5),
]
# By default the top is an edge parallel to the Y axis.
# Need to rotate around the y axis by half the supplement of the angle
# between faces, i.e. (pi - angle_between_faces) / 2, to get the desired
# orientation.
#                      /O\    (top edge)
#                     /   \                              Z
#   (adjacent face)  /     \  (adjacent face)            ^
#                   / angle_between_faces \              |
#                  /                       \             |
#                 /                         \           YO-----> X
# This results in:
#  (adjacent face is now the top plane)
#  ----------------------O\  (top edge)
#                           \
#                            \
#                             \     (adjacent face)
#                              \
#                               \
#                                \
angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3))
rotation_angle = (np.pi - angle_between_faces) / 2
rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle)
rotation_matrix = rotation.as_matrix()
vertices = np.dot(vertices, rotation_matrix)
return TriangularMesh(vertices=vertices.astype(np.float32),
faces=np.array(faces, dtype=np.int32))
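# A quick consistency sketch (hypothetical helper, illustration only): in a
# regular icosahedron all 30 edges have the same length, which for a
# circumscribed unit sphere is 4 / sqrt(10 + 2 * sqrt(5)) ~= 1.0515.
def _example_check_icosahedron_edges() -> None:
  mesh = get_icosahedron()
  senders, receivers = faces_to_edges(mesh.faces)
  edge_lengths = np.linalg.norm(
      mesh.vertices[senders] - mesh.vertices[receivers], axis=-1)
  assert edge_lengths.shape == (60,)  # 30 undirected edges, both directions.
  np.testing.assert_allclose(edge_lengths, edge_lengths[0], rtol=1e-5)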
def _two_split_unit_sphere_triangle_faces(
triangular_mesh: TriangularMesh) -> TriangularMesh:
"""Splits each triangular face into 4 triangles keeping the orientation."""
# Every time we split a triangle into 4 we will be adding 3 extra vertices,
# located at the edge centres.
# This class handles the positioning of the new vertices, and avoids creating
# duplicates.
new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices)
new_faces = []
for ind1, ind2, ind3 in triangular_mesh.faces:
# Transform each triangular face into 4 triangles,
# preserving the orientation.
#                      ind3
#                     /    \
#                    /      \
#                   /   #3   \
#                  /          \
#               ind31 -------- ind23
#               /   \          /   \
#              /     \   #4   /     \
#             /  #1   \      /  #2   \
#            /         \    /         \
#          ind1 -------- ind12 -------- ind2
ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2))
ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3))
ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1))
# Note how each of the 4 triangular new faces specifies the order of the
# vertices to preserve the orientation of the original face. As the input
# face should always be counter-clockwise as specified in the diagram,
# this means child faces should also be counter-clockwise.
new_faces.extend([[ind1, ind12, ind31], # 1
[ind12, ind2, ind23], # 2
[ind31, ind23, ind3], # 3
[ind12, ind23, ind31], # 4
])
return TriangularMesh(vertices=new_vertices_builder.get_all_vertices(),
faces=np.array(new_faces, dtype=np.int32))
class _ChildVerticesBuilder(object):
"""Bookkeeping of new child vertices added to an existing set of vertices."""
def __init__(self, parent_vertices):
# Because the same new vertex will be required when splitting adjacent
# triangles (which share an edge) we keep them in a hash table indexed by
# sorted indices of the vertices adjacent to the edge, to avoid creating
# duplicated child vertices.
self._child_vertices_index_mapping = {}
self._parent_vertices = parent_vertices
# We start with all previous vertices.
self._all_vertices_list = list(parent_vertices)
def _get_child_vertex_key(self, parent_vertex_indices):
return tuple(sorted(parent_vertex_indices))
def _create_child_vertex(self, parent_vertex_indices):
"""Creates a new vertex."""
# Position for new vertex is the middle point, between the parent points,
# projected to unit sphere.
child_vertex_position = self._parent_vertices[
list(parent_vertex_indices)].mean(0)
child_vertex_position /= np.linalg.norm(child_vertex_position)
# Add the vertex to the output list. The index for this new vertex will
# match the length of the list before adding it.
child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
self._child_vertices_index_mapping[child_vertex_key] = len(
self._all_vertices_list)
self._all_vertices_list.append(child_vertex_position)
def get_new_child_vertex_index(self, parent_vertex_indices):
"""Returns index for a child vertex, creating it if necessary."""
# Get the key to see if we already have a new vertex in the middle.
child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
if child_vertex_key not in self._child_vertices_index_mapping:
self._create_child_vertex(parent_vertex_indices)
return self._child_vertices_index_mapping[child_vertex_key]
def get_all_vertices(self):
"""Returns an array with old vertices."""
return np.array(self._all_vertices_list)
def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Transforms polygonal faces to sender and receiver indices.
It does so by transforming every face into N_i edges, such that a triangular
face with indices [0, 1, 2] adds the three edges 0->1, 1->2, and 2->0.
If all faces have consistent orientation, and the surface represented by the
faces is closed, then every edge in a polygon with a certain orientation
is also part of another polygon with the opposite orientation. In this
situation, the edges returned by the method are always bidirectional.
Args:
faces: Integer array of shape [num_faces, 3]. Contains node indices
adjacent to each face.
Returns:
Tuple with sender/receiver indices, each of shape [num_edges=num_faces*3].
"""
assert faces.ndim == 2
assert faces.shape[-1] == 3
senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]])
receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]])
return senders, receivers
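# A small sketch (hypothetical helper, illustration only) of the
# bidirectionality property described in the docstring above: for the closed,
# consistently oriented icosahedral mesh, the reverse of every directed edge
# is also present among the returned sender/receiver pairs.
def _example_edges_are_bidirectional() -> None:
  mesh = get_icosahedron()
  senders, receivers = faces_to_edges(mesh.faces)
  forward = set(zip(senders.tolist(), receivers.tolist()))
  backward = set(zip(receivers.tolist(), senders.tolist()))
  assert forward == backward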
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for icosahedral_mesh."""
from absl.testing import absltest
from absl.testing import parameterized
import chex
from graphcast import icosahedral_mesh
import numpy as np
def _get_mesh_spec(splits: int):
"""Returns size of the final icosahedral mesh resulting from the splitting."""
num_vertices = 12
num_faces = 20
for _ in range(splits):
# Each previous face adds three new vertices, but each vertex is shared
# by two faces.
num_vertices += num_faces * 3 // 2
num_faces *= 4
return num_vertices, num_faces
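# For reference, the sizes this recurrence produces (easy to verify by hand):
#   splits:        0    1    2    3     4
#   num_vertices:  12   42   162  642   2562
#   num_faces:     20   80   320  1280  5120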
class IcosahedralMeshTest(parameterized.TestCase):
def test_icosahedron(self):
mesh = icosahedral_mesh.get_icosahedron()
_assert_valid_mesh(
mesh, num_expected_vertices=12, num_expected_faces=20)
@parameterized.parameters(list(range(5)))
def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits):
meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
splits=splits)
prev_vertices = None
for mesh_i, mesh in enumerate(meshes):
# Check that `mesh` is valid.
num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i)
_assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces)
# Check that the first N vertices from this mesh match all of the
# vertices from the previous mesh.
if prev_vertices is not None:
leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]]
np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices)
# Increase the expected/previous values for the next iteration.
if mesh_i < len(meshes) - 1:
prev_vertices = mesh.vertices
@parameterized.parameters(list(range(4)))
def test_merge_meshes(self, splits):
mesh_hierarchy = (
icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
splits=splits))
mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy)
expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0)
np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices)
np.testing.assert_array_equal(mesh.faces, expected_faces)
def test_faces_to_edges(self):
faces = np.array([[0, 1, 2],
[3, 4, 5]])
# This also documents the order of the edges returned by the method.
expected_edges = np.array(
[[0, 1],
[3, 4],
[1, 2],
[4, 5],
[2, 0],
[5, 3]])
expected_senders = expected_edges[:, 0]
expected_receivers = expected_edges[:, 1]
senders, receivers = icosahedral_mesh.faces_to_edges(faces)
np.testing.assert_array_equal(senders, expected_senders)
np.testing.assert_array_equal(receivers, expected_receivers)
def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces):
vertices = mesh.vertices
faces = mesh.faces
chex.assert_shape(vertices, [num_expected_vertices, 3])
chex.assert_shape(faces, [num_expected_faces, 3])
# Vertices norm should be 1.
vertices_norm = np.linalg.norm(vertices, axis=-1)
np.testing.assert_allclose(vertices_norm, 1., rtol=1e-6)
_assert_positive_face_orientation(vertices, faces)
def _assert_positive_face_orientation(vertices, faces):
# Obtain a unit vector that points in the direction of the face normal.
face_orientation = np.cross(vertices[faces[:, 1]] - vertices[faces[:, 0]],
vertices[faces[:, 2]] - vertices[faces[:, 1]])
face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True)
# And a unit vector pointing from the origin to the center of the face.
face_centers = vertices[faces].mean(1)
face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True)
# Positive orientation means those two vectors should be parallel
# (dot product ~ 1), and not anti-parallel (dot product ~ -1).
dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers)
# Check that the face normal is parallel to the vector that joins the center
# of the face to the center of the sphere. Note we need a small tolerance
# because some discretizations are not exactly uniform, so it will not be
# exactly parallel.
np.testing.assert_allclose(dot_center_orientation, 1., atol=6e-4)
if __name__ == "__main__":
absltest.main()
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions (and terms for use in loss functions) used for weather."""
from typing import Mapping
from graphcast import xarray_tree
import numpy as np
from typing_extensions import Protocol
import xarray
LossAndDiagnostics = tuple[xarray.DataArray, xarray.Dataset]
class LossFunction(Protocol):
"""A loss function.
This is a protocol so it's fine to use a plain function which 'quacks like'
this. This is just to document the interface.
"""
def __call__(self,
predictions: xarray.Dataset,
targets: xarray.Dataset,
**optional_kwargs) -> LossAndDiagnostics:
"""Computes a loss function.
Args:
predictions: Dataset of predictions.
targets: Dataset of targets.
**optional_kwargs: Implementations may support extra optional kwargs.
Returns:
loss: A DataArray with dimensions ('batch',) containing losses for each
element of the batch. These will be averaged to give the final
loss, locally and across replicas.
diagnostics: Mapping of additional quantities to log by name alongside the
loss. These will typically correspond to terms in the loss. They
should also have dimensions ('batch',) and will be averaged over the
batch before logging.
"""
def weighted_mse_per_level(
predictions: xarray.Dataset,
targets: xarray.Dataset,
per_variable_weights: Mapping[str, float],
) -> LossAndDiagnostics:
"""Latitude- and pressure-level-weighted MSE loss."""
def loss(prediction, target):
loss = (prediction - target)**2
loss *= normalized_latitude_weights(target).astype(loss.dtype)
if 'level' in target.dims:
loss *= normalized_level_weights(target).astype(loss.dtype)
return _mean_preserving_batch(loss)
losses = xarray_tree.map_structure(loss, predictions, targets)
return sum_per_variable_losses(losses, per_variable_weights)
def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray:
return x.mean([d for d in x.dims if d != 'batch'], skipna=False)
def sum_per_variable_losses(
per_variable_losses: Mapping[str, xarray.DataArray],
weights: Mapping[str, float],
) -> LossAndDiagnostics:
"""Weighted sum of per-variable losses."""
if not set(weights.keys()).issubset(set(per_variable_losses.keys())):
raise ValueError(
'Passing a weight that does not correspond to any variable '
f'{set(weights.keys())-set(per_variable_losses.keys())}')
weighted_per_variable_losses = {
name: loss * weights.get(name, 1)
for name, loss in per_variable_losses.items()
}
total = xarray.concat(
weighted_per_variable_losses.values(), dim='variable', join='exact').sum(
'variable', skipna=False)
return total, per_variable_losses # pytype: disable=bad-return-type
def normalized_level_weights(data: xarray.DataArray) -> xarray.DataArray:
"""Weights proportional to pressure at each level."""
level = data.coords['level']
return level / level.mean(skipna=False)
def normalized_latitude_weights(data: xarray.DataArray) -> xarray.DataArray:
"""Weights based on latitude, roughly proportional to grid cell area.
This method supports two use cases only (both for equispaced values):
* Latitude values such that the closest value to the pole is at latitude
(90 - d_lat/2), where d_lat is the difference between contiguous latitudes.
For example: [-89, -87, -85, ..., 85, 87, 89] (d_lat = 2).
In this case each point with `lat` value represents a sphere slice between
`lat - d_lat/2` and `lat + d_lat/2`, and the area of this slice would be
proportional to:
`sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)`, and
we can simply omit the term `2 * sin(d_lat/2)` which is just a constant
that cancels during normalization.
* Latitude values that fall exactly at the poles.
For example: [-90, -88, -86, ..., 86, 88, 90] (d_lat = 2).
In this case each point with `lat` value also represents
a sphere slice between `lat - d_lat/2` and `lat + d_lat/2`,
except for the points at the poles, that represent a slice between
`90 - d_lat/2` and `90` or, `-90` and `-90 + d_lat/2`.
The areas of the first type of point are still proportional to:
* sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)
but for the points at the poles it is now:
* sin(90) - sin(90 - d_lat/2) = 2 * sin(d_lat/4) ^ 2
and we will be using these weights, depending on whether we are looking at
pole cells, or non-pole cells (omitting the common factor of 2 which will be
absorbed by the normalization).
It can be shown via a limit, or simple geometry, that in the small-angle
regime, the area associated with a pole point is 1/8th of the area covered
by each of its nearest non-pole points, and we verify this in the unit tests.
Args:
data: `DataArray` with latitude coordinates.
Returns:
Unit mean latitude weights.
"""
latitude = data.coords['lat']
if np.any(np.isclose(np.abs(latitude), 90.)):
weights = _weight_for_latitude_vector_with_poles(latitude)
else:
weights = _weight_for_latitude_vector_without_poles(latitude)
return weights / weights.mean(skipna=False)
def _weight_for_latitude_vector_without_poles(latitude):
"""Weights for uniform latitudes of the form [+-90-+d/2, ..., -+90+-d/2]."""
delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
if (not np.isclose(np.max(latitude), 90 - delta_latitude/2) or
not np.isclose(np.min(latitude), -90 + delta_latitude/2)):
raise ValueError(
f'Latitude vector {latitude} does not start/end at '
'+- (90 - delta_latitude/2) degrees.')
return np.cos(np.deg2rad(latitude))
def _weight_for_latitude_vector_with_poles(latitude):
"""Weights for uniform latitudes of the form [+- 90, ..., -+90]."""
delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
if (not np.isclose(np.max(latitude), 90.) or
not np.isclose(np.min(latitude), -90.)):
raise ValueError(
f'Latitude vector {latitude} does not start/end at +- 90 degrees.')
weights = np.cos(np.deg2rad(latitude)) * np.sin(np.deg2rad(delta_latitude/2))
# The two checks above are enough to guarantee that latitudes are sorted, so
# the extremes are the poles.
weights[[0, -1]] = np.sin(np.deg2rad(delta_latitude/4)) ** 2
return weights
def _check_uniform_spacing_and_get_delta(vector):
diff = np.diff(vector)
if not np.all(np.isclose(diff[0], diff)):
raise ValueError(f'Vector {vector} is not uniformly spaced.')
return diff[0]
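# A numeric sketch (hypothetical helper, illustration only) of the 1/8 ratio
# mentioned in `normalized_latitude_weights`: for a small d_lat, the weight of
# a pole point is ~1/8 of the weight of its nearest non-pole neighbour.
def _example_pole_weight_ratio(d_lat: float = 1.0) -> float:
  latitudes = np.arange(-90., 90. + d_lat, d_lat)  # Includes both poles.
  weights = _weight_for_latitude_vector_with_poles(latitudes)
  ratio = weights[-1] / weights[-2]  # Pole weight vs its nearest neighbour.
  # For d_lat = 1 this evaluates to ~0.125.
  return ratio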
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for building models."""
from typing import Mapping, Optional, Tuple
import numpy as np
from scipy.spatial import transform
import xarray
def get_graph_spatial_features(
*, node_lat: np.ndarray, node_lon: np.ndarray,
senders: np.ndarray, receivers: np.ndarray,
add_node_positions: bool,
add_node_latitude: bool,
add_node_longitude: bool,
add_relative_positions: bool,
relative_longitude_local_coordinates: bool,
relative_latitude_local_coordinates: bool,
sine_cosine_encoding: bool = False,
encoding_num_freqs: int = 10,
encoding_multiplicative_factor: float = 1.2,
) -> Tuple[np.ndarray, np.ndarray]:
"""Computes spatial features for the nodes.
Args:
node_lat: Latitudes in the [-90, 90] interval of shape [num_nodes]
node_lon: Longitudes in the [0, 360] interval of shape [num_nodes]
senders: Sender indices of shape [num_edges]
receivers: Receiver indices of shape [num_edges]
add_node_positions: Add unit norm absolute positions.
add_node_latitude: Add a feature for latitude (cos(90 - lat)).
Note even if this is set to False, the model may be able to infer the
latitude from relative features, unless
`relative_latitude_local_coordinates` is also True, or if there is any
bias on the relative edge sizes for different latitudes.
add_node_longitude: Add features for longitude (cos(lon), sin(lon)).
Note even if this is set to False, the model may be able to infer the
longitude from relative features, unless
`relative_longitude_local_coordinates` is also True, or if there is any
bias on the relative edge sizes for different longitudes.
add_relative_positions: Whether to add relative positions in R3 to the edges.
relative_longitude_local_coordinates: If True, relative positions are
computed in a local space where the receiver is at 0 longitude.
relative_latitude_local_coordinates: If True, relative positions are
computed in a local space where the receiver is at 0 latitude.
sine_cosine_encoding: If True, we will transform the node/edge features
with sine and cosine functions, similar to NeRF positional encoding.
encoding_num_freqs: Number of frequencies used for the encoding.
encoding_multiplicative_factor: Multiplicative factor between consecutive
frequencies.
Returns:
Arrays of shape [num_nodes, num_features] and [num_edges, num_features],
with node and edge features respectively.
"""
num_nodes = node_lat.shape[0]
num_edges = senders.shape[0]
dtype = node_lat.dtype
node_phi, node_theta = lat_lon_deg_to_spherical(node_lat, node_lon)
# Computing some node features.
node_features = []
if add_node_positions:
# Already in [-1, 1.] range.
node_features.extend(spherical_to_cartesian(node_phi, node_theta))
if add_node_latitude:
# Using the cos of theta.
# From 1. (north pole) to -1 (south pole).
node_features.append(np.cos(node_theta))
if add_node_longitude:
# Using the cos and sin, which is already normalized.
node_features.append(np.cos(node_phi))
node_features.append(np.sin(node_phi))
if not node_features:
node_features = np.zeros([num_nodes, 0], dtype=dtype)
else:
node_features = np.stack(node_features, axis=-1)
# Computing some edge features.
edge_features = []
if add_relative_positions:
relative_position = get_relative_position_in_receiver_local_coordinates(
node_phi=node_phi,
node_theta=node_theta,
senders=senders,
receivers=receivers,
latitude_local_coordinates=relative_latitude_local_coordinates,
longitude_local_coordinates=relative_longitude_local_coordinates
)
# Note this is L2 distance in 3d space, rather than geodesic distance.
relative_edge_distances = np.linalg.norm(
relative_position, axis=-1, keepdims=True)
# Normalize to the maximum edge distance. Note that we expect to always
# have an edge that goes in the opposite direction of any given edge
# so the distribution of relative positions should be symmetric around
# zero. So by scaling by the maximum length, we expect all relative
# positions to fall in the [-1., 1.] interval, and all relative distances
# to fall in the [0., 1.] interval.
max_edge_distance = relative_edge_distances.max()
edge_features.append(relative_edge_distances / max_edge_distance)
edge_features.append(relative_position / max_edge_distance)
if not edge_features:
edge_features = np.zeros([num_edges, 0], dtype=dtype)
else:
edge_features = np.concatenate(edge_features, axis=-1)
if sine_cosine_encoding:
def sine_cosine_transform(x: np.ndarray) -> np.ndarray:
freqs = encoding_multiplicative_factor**np.arange(encoding_num_freqs)
phases = freqs * x[..., None]
x_sin = np.sin(phases)
x_cos = np.cos(phases)
x_cat = np.concatenate([x_sin, x_cos], axis=-1)
return x_cat.reshape([x.shape[0], -1])
node_features = sine_cosine_transform(node_features)
edge_features = sine_cosine_transform(edge_features)
return node_features, edge_features
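# A usage sketch (hypothetical helper, illustration only): spatial features
# for the mesh graph of a once-split icosahedral mesh, built with the sibling
# `icosahedral_mesh` module. The flag values below are just one example
# configuration, not the configuration used by any particular model.
def _example_mesh_graph_features():
  from graphcast import icosahedral_mesh
  mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
      splits=1)[-1]
  senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces)
  node_phi, node_theta = cartesian_to_spherical(*mesh.vertices.T)
  node_lat, node_lon = spherical_to_lat_lon(node_phi, node_theta)
  node_features, edge_features = get_graph_spatial_features(
      node_lat=node_lat, node_lon=node_lon,
      senders=senders, receivers=receivers,
      add_node_positions=False,
      add_node_latitude=True,
      add_node_longitude=True,
      add_relative_positions=True,
      relative_longitude_local_coordinates=True,
      relative_latitude_local_coordinates=True)
  # 3 node features: cos(theta), cos(phi), sin(phi); 4 edge features: the
  # normalized edge length plus the 3 relative position components.
  assert node_features.shape == (42, 3)
  assert edge_features.shape == (80 * 3, 4)
  return node_features, edge_features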
def lat_lon_to_leading_axes(
grid_xarray: xarray.DataArray) -> xarray.DataArray:
"""Reorders xarray so lat/lon axes come first."""
# leading + ["lat", "lon"] + trailing
# to
# ["lat", "lon"] + leading + trailing
return grid_xarray.transpose("lat", "lon", ...)
def restore_leading_axes(grid_xarray: xarray.DataArray) -> xarray.DataArray:
"""Reorders xarray so batch/time/level axes come first (if present)."""
# ["lat", "lon"] + [(batch,) (time,) (level,)] + trailing
# to
# [(batch,) (time,) (level,)] + ["lat", "lon"] + trailing
input_dims = list(grid_xarray.dims)
output_dims = list(input_dims)
for leading_key in ["level", "time", "batch"]: # reverse order for insert
if leading_key in input_dims:
output_dims.remove(leading_key)
output_dims.insert(0, leading_key)
return grid_xarray.transpose(*output_dims)
def lat_lon_deg_to_spherical(node_lat: np.ndarray,
node_lon: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
phi = np.deg2rad(node_lon)
theta = np.deg2rad(90 - node_lat)
return phi, theta
def spherical_to_lat_lon(phi: np.ndarray,
theta: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
lon = np.mod(np.rad2deg(phi), 360)
lat = 90 - np.rad2deg(theta)
return lat, lon
def cartesian_to_spherical(x: np.ndarray,
y: np.ndarray,
z: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
phi = np.arctan2(y, x)
with np.errstate(invalid="ignore"): # circumventing b/253179568
theta = np.arccos(z) # Assuming unit radius.
return phi, theta
def spherical_to_cartesian(
phi: np.ndarray, theta: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
# Assuming unit radius.
return (np.cos(phi)*np.sin(theta),
np.sin(phi)*np.sin(theta),
np.cos(theta))
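# A round-trip sketch (hypothetical helper, illustration only): converting
# lat/lon -> spherical -> Cartesian -> spherical -> lat/lon should recover the
# original coordinates up to floating point error.
def _example_lat_lon_roundtrip() -> None:
  lat = np.array([-45., 0., 60.])
  lon = np.array([10., 180., 350.])
  phi, theta = lat_lon_deg_to_spherical(lat, lon)
  x, y, z = spherical_to_cartesian(phi, theta)
  phi2, theta2 = cartesian_to_spherical(x, y, z)
  lat2, lon2 = spherical_to_lat_lon(phi2, theta2)
  np.testing.assert_allclose(lat2, lat, atol=1e-6)
  np.testing.assert_allclose(lon2, lon, atol=1e-6)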
def get_relative_position_in_receiver_local_coordinates(
node_phi: np.ndarray,
node_theta: np.ndarray,
senders: np.ndarray,
receivers: np.ndarray,
latitude_local_coordinates: bool,
longitude_local_coordinates: bool
) -> np.ndarray:
"""Returns relative position features for the edges.
The relative positions will be computed in a rotated space for a local
coordinate system as defined by the receiver. The relative positions are
simply obtained as sender position minus receiver position in
that local coordinate system, after the rotation in R^3.
Args:
node_phi: [num_nodes] with azimuthal angles (longitude).
node_theta: [num_nodes] with polar angles (colatitude).
senders: [num_edges] with indices.
receivers: [num_edges] with indices.
latitude_local_coordinates: Whether to rotate edges such that the
positions are computed with the receiver always at latitude 0.
longitude_local_coordinates: Whether to rotate edges such that the
positions are computed with the receiver always at longitude 0.
Returns:
Array of relative positions in R3 [num_edges, 3]
"""
node_pos = np.stack(spherical_to_cartesian(node_phi, node_theta), axis=-1)
# No rotation in this case.
if not (latitude_local_coordinates or longitude_local_coordinates):
return node_pos[senders] - node_pos[receivers]
# Get rotation matrices for the local space for every node.
rotation_matrices = get_rotation_matrices_to_local_coordinates(
reference_phi=node_phi,
reference_theta=node_theta,
rotate_latitude=latitude_local_coordinates,
rotate_longitude=longitude_local_coordinates)
# Each edge will be rotated according to the rotation matrix of its receiver
# node.
edge_rotation_matrices = rotation_matrices[receivers]
# Rotate all nodes to the rotated space of the corresponding edge.
# Note for receivers we can also do the matmul first and the gather second:
# ```
# receiver_pos_in_rotated_space = rotate_with_matrices(
# rotation_matrices, node_pos)[receivers]
# ```
# which is more efficient, however, we do gather first to keep it more
# symmetric with the sender computation.
receiver_pos_in_rotated_space = rotate_with_matrices(
edge_rotation_matrices, node_pos[receivers])
sender_pos_in_rotated_space = rotate_with_matrices(
edge_rotation_matrices, node_pos[senders])
# Note, here, that because the rotated space is chosen according to the
# receiver, if:
# * latitude_local_coordinates = True: latitude for the receivers will be
# 0, that is the z coordinate will always be 0.
# * longitude_local_coordinates = True: longitude for the receivers will be
# 0, that is the y coordinate will be 0.
# Now we can just subtract.
# Note we are rotating to a local coordinate system, where the y-z axes are
# parallel to a tangent plane to the sphere, but still remain in a 3d space.
# Note that if both `latitude_local_coordinates` and
# `longitude_local_coordinates` are True, and edges are short,
# then the difference in x coordinate between sender and receiver
should be small, so we could consider dropping the new x coordinate if
we wanted to work on the tangent plane, however in doing so
# we would lose information about the curvature of the mesh, which may be
# important for very coarse meshes.
return sender_pos_in_rotated_space - receiver_pos_in_rotated_space
def get_rotation_matrices_to_local_coordinates(
reference_phi: np.ndarray,
reference_theta: np.ndarray,
rotate_latitude: bool,
rotate_longitude: bool) -> np.ndarray:
"""Returns a rotation matrix to rotate to a point based on a reference vector.
The rotation matrix is built such that a vector in the
same coordinate system at the reference point that points towards the pole
before the rotation, continues to point towards the pole after the rotation.
Args:
reference_phi: [leading_axis] Azimuthal angles (longitude) of the reference.
reference_theta: [leading_axis] Polar angles (colatitude) of the reference.
rotate_latitude: Whether to produce a rotation matrix that would rotate
R^3 vectors to zero latitude.
rotate_longitude: Whether to produce a rotation matrix that would rotate
R^3 vectors to zero longitude.
Returns:
Matrices of shape [leading_axis] such that when applied to the reference
position with `rotate_with_matrices(rotation_matrices, reference_pos)`
* phi goes to 0. if "rotate_longitude" is True.
* theta goes to np.pi / 2 if "rotate_latitude" is True.
The rotation consists of:
* rotate_latitude = False, rotate_longitude = True:
Latitude preserving rotation.
* rotate_latitude = True, rotate_longitude = True:
Latitude preserving rotation, followed by longitude preserving
rotation.
* rotate_latitude = True, rotate_longitude = False:
Latitude preserving rotation, followed by longitude preserving
rotation, and the inverse of the latitude preserving rotation. Note
this is computationally different from a single rotation that only
changes the latitude. We do it like this so that the polar geodesic
curve continues to be aligned with one of the axes after the rotation.
"""
if rotate_longitude and rotate_latitude:
# We first rotate around the z axis "minus the azimuthal angle", to get the
# point with zero longitude
azimuthal_rotation = - reference_phi
# Then we will do a polar rotation (which can be done around the y axis now
# that we are at longitude 0), by "minus the polar angle plus pi/2", to get
# the point with zero latitude.
polar_rotation = - reference_theta + np.pi/2
return transform.Rotation.from_euler(
"zy", np.stack([azimuthal_rotation, polar_rotation],
axis=1)).as_matrix()
elif rotate_longitude:
# Just like the previous case, but applying only the azimuthal rotation.
azimuthal_rotation = - reference_phi
return transform.Rotation.from_euler("z", azimuthal_rotation).as_matrix()
elif rotate_latitude:
# Just like the first case, but after doing the polar rotation, undoing
# the azimuthal rotation.
azimuthal_rotation = - reference_phi
polar_rotation = - reference_theta + np.pi/2
return transform.Rotation.from_euler(
"zyz", np.stack(
[azimuthal_rotation, polar_rotation, -azimuthal_rotation]
, axis=1)).as_matrix()
else:
raise ValueError(
"At least one of longitude and latitude should be rotated.")
def rotate_with_matrices(rotation_matrices: np.ndarray, positions: np.ndarray
) -> np.ndarray:
return np.einsum("bji,bi->bj", rotation_matrices, positions)
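# A small verification sketch (hypothetical helper, illustration only):
# rotating each reference point with its own rotation matrix should place it
# at latitude 0 and longitude 0, i.e. at Cartesian position (1, 0, 0).
def _example_local_coordinate_rotation() -> None:
  reference_phi = np.deg2rad(np.array([10., 100., 250.]))
  reference_theta = np.deg2rad(np.array([20., 90., 160.]))
  rotation_matrices = get_rotation_matrices_to_local_coordinates(
      reference_phi=reference_phi,
      reference_theta=reference_theta,
      rotate_latitude=True,
      rotate_longitude=True)
  reference_pos = np.stack(
      spherical_to_cartesian(reference_phi, reference_theta), axis=-1)
  rotated = rotate_with_matrices(rotation_matrices, reference_pos)
  np.testing.assert_allclose(rotated, np.array([[1., 0., 0.]] * 3), atol=1e-9)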
def get_bipartite_graph_spatial_features(
*,
senders_node_lat: np.ndarray,
senders_node_lon: np.ndarray,
senders: np.ndarray,
receivers_node_lat: np.ndarray,
receivers_node_lon: np.ndarray,
receivers: np.ndarray,
add_node_positions: bool,
add_node_latitude: bool,
add_node_longitude: bool,
add_relative_positions: bool,
edge_normalization_factor: Optional[float] = None,
relative_longitude_local_coordinates: bool,
relative_latitude_local_coordinates: bool,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Computes spatial features for the nodes.
This function is almost identical to `get_graph_spatial_features`. The only
difference is that sender nodes and receiver nodes can be in different arrays.
This is necessary to enable combination with typed graphs.
Args:
senders_node_lat: Latitudes in the [-90, 90] interval of shape
[num_sender_nodes]
senders_node_lon: Longitudes in the [0, 360] interval of shape
[num_sender_nodes]
senders: Sender indices of shape [num_edges], indices in [0,
num_sender_nodes)
receivers_node_lat: Latitudes in the [-90, 90] interval of shape
[num_receiver_nodes]
receivers_node_lon: Longitudes in the [0, 360] interval of shape
[num_receiver_nodes]
receivers: Receiver indices of shape [num_edges], indices in [0,
num_receiver_nodes)
add_node_positions: Add unit norm absolute positions.
add_node_latitude: Add a feature for latitude (cos(90 - lat)). Note even if
this is set to False, the model may be able to infer the latitude from
relative features, unless `relative_latitude_local_coordinates` is also
True, or if there is any bias on the relative edge sizes for different
latitudes.
add_node_longitude: Add features for longitude (cos(lon), sin(lon)). Note
even if this is set to False, the model may be able to infer the longitude
from relative features, unless `relative_longitude_local_coordinates` is
also True, or if there is any bias on the relative edge sizes for
different longitudes.
add_relative_positions: Whether to add relative positions in R3 to the edges.
edge_normalization_factor: Allows explicitly controlling edge normalization.
If None, defaults to max edge length. This supports using pre-trained
model weights with a different graph structure to what it was trained on.
relative_longitude_local_coordinates: If True, relative positions are
computed in a local space where the receiver is at 0 longitude.
relative_latitude_local_coordinates: If True, relative positions are
computed in a local space where the receiver is at 0 latitude.
Returns:
Arrays of shape [num_sender_nodes, num_features],
[num_receiver_nodes, num_features] and [num_edges, num_features], with
sender node, receiver node and edge features respectively.
"""
num_senders = senders_node_lat.shape[0]
num_receivers = receivers_node_lat.shape[0]
num_edges = senders.shape[0]
dtype = senders_node_lat.dtype
assert receivers_node_lat.dtype == dtype
senders_node_phi, senders_node_theta = lat_lon_deg_to_spherical(
senders_node_lat, senders_node_lon)
receivers_node_phi, receivers_node_theta = lat_lon_deg_to_spherical(
receivers_node_lat, receivers_node_lon)
# Computing some node features.
senders_node_features = []
receivers_node_features = []
if add_node_positions:
# Already in [-1, 1.] range.
senders_node_features.extend(
spherical_to_cartesian(senders_node_phi, senders_node_theta))
receivers_node_features.extend(
spherical_to_cartesian(receivers_node_phi, receivers_node_theta))
if add_node_latitude:
# Using the cos of theta.
# From 1. (north pole) to -1 (south pole).
senders_node_features.append(np.cos(senders_node_theta))
receivers_node_features.append(np.cos(receivers_node_theta))
if add_node_longitude:
# Using the cos and sin, which is already normalized.
senders_node_features.append(np.cos(senders_node_phi))
senders_node_features.append(np.sin(senders_node_phi))
receivers_node_features.append(np.cos(receivers_node_phi))
receivers_node_features.append(np.sin(receivers_node_phi))
if not senders_node_features:
senders_node_features = np.zeros([num_senders, 0], dtype=dtype)
receivers_node_features = np.zeros([num_receivers, 0], dtype=dtype)
else:
senders_node_features = np.stack(senders_node_features, axis=-1)
receivers_node_features = np.stack(receivers_node_features, axis=-1)
# Computing some edge features.
edge_features = []
if add_relative_positions:
relative_position = get_bipartite_relative_position_in_receiver_local_coordinates( # pylint: disable=line-too-long
senders_node_phi=senders_node_phi,
senders_node_theta=senders_node_theta,
receivers_node_phi=receivers_node_phi,
receivers_node_theta=receivers_node_theta,
senders=senders,
receivers=receivers,
latitude_local_coordinates=relative_latitude_local_coordinates,
longitude_local_coordinates=relative_longitude_local_coordinates)
# Note this is L2 distance in 3d space, rather than geodesic distance.
relative_edge_distances = np.linalg.norm(
relative_position, axis=-1, keepdims=True)
if edge_normalization_factor is None:
# Normalize to the maximum edge distance. Note that we expect to always
# have an edge that goes in the opposite direction of any given edge
# so the distribution of relative positions should be symmetric around
# zero. So by scaling by the maximum length, we expect all relative
# positions to fall in the [-1., 1.] interval, and all relative distances
# to fall in the [0., 1.] interval.
edge_normalization_factor = relative_edge_distances.max()
edge_features.append(relative_edge_distances / edge_normalization_factor)
edge_features.append(relative_position / edge_normalization_factor)
if not edge_features:
edge_features = np.zeros([num_edges, 0], dtype=dtype)
else:
edge_features = np.concatenate(edge_features, axis=-1)
return senders_node_features, receivers_node_features, edge_features
def get_bipartite_relative_position_in_receiver_local_coordinates(
senders_node_phi: np.ndarray,
senders_node_theta: np.ndarray,
senders: np.ndarray,
receivers_node_phi: np.ndarray,
receivers_node_theta: np.ndarray,
receivers: np.ndarray,
latitude_local_coordinates: bool,
longitude_local_coordinates: bool) -> np.ndarray:
"""Returns relative position features for the edges.
This function is equivalent to
`get_relative_position_in_receiver_local_coordinates`, but adapted to work
with bipartite typed graphs.
The relative positions will be computed in a rotated space for a local
coordinate system as defined by the receiver. The relative positions are
simply obtained as sender position minus receiver position in
that local coordinate system, after the rotation in R^3.
Args:
senders_node_phi: [num_sender_nodes] with azimuthal angles (longitude).
senders_node_theta: [num_sender_nodes] with polar angles (colatitude).
senders: [num_edges] with indices into sender nodes.
receivers_node_phi: [num_receiver_nodes] with azimuthal angles (longitude).
receivers_node_theta: [num_receiver_nodes] with polar angles (colatitude).
receivers: [num_edges] with indices into receiver nodes.
latitude_local_coordinates: Whether to rotate edges such that the
positions are computed with the receiver always at latitude 0.
longitude_local_coordinates: Whether to rotate edges such that the
positions are computed with the receiver always at longitude 0.
Returns:
Array of relative positions in R3 [num_edges, 3]
"""
senders_node_pos = np.stack(
spherical_to_cartesian(senders_node_phi, senders_node_theta), axis=-1)
receivers_node_pos = np.stack(
spherical_to_cartesian(receivers_node_phi, receivers_node_theta), axis=-1)
# No rotation in this case.
if not (latitude_local_coordinates or longitude_local_coordinates):
return senders_node_pos[senders] - receivers_node_pos[receivers]
# Get rotation matrices for the local space for every receiver node.
receiver_rotation_matrices = get_rotation_matrices_to_local_coordinates(
reference_phi=receivers_node_phi,
reference_theta=receivers_node_theta,
rotate_latitude=latitude_local_coordinates,
rotate_longitude=longitude_local_coordinates)
# Each edge will be rotated according to the rotation matrix of its receiver
# node.
edge_rotation_matrices = receiver_rotation_matrices[receivers]
# Rotate all nodes to the rotated space of the corresponding edge.
# Note for receivers we can also do the matmul first and the gather second:
# ```
# receiver_pos_in_rotated_space = rotate_with_matrices(
# rotation_matrices, node_pos)[receivers]
# ```
# which is more efficient, however, we do gather first to keep it more
# symmetric with the sender computation.
receiver_pos_in_rotated_space = rotate_with_matrices(
edge_rotation_matrices, receivers_node_pos[receivers])
sender_pos_in_rotated_space = rotate_with_matrices(
edge_rotation_matrices, senders_node_pos[senders])
# Note, here, that because the rotated space is chosen according to the
# receiver, if:
# * latitude_local_coordinates = True: latitude for the receivers will be
# 0, that is the z coordinate will always be 0.
# * longitude_local_coordinates = True: longitude for the receivers will be
# 0, that is the y coordinate will be 0.
# Now we can just subtract.
# Note we are rotating to a local coordinate system, where the y-z axes are
# parallel to a tangent plane to the sphere, but still remain in a 3d space.
# Note that if both `latitude_local_coordinates` and
# `longitude_local_coordinates` are True, and edges are short,
# then the difference in x coordinate between sender and receiver
should be small, so we could consider dropping the new x coordinate if
we wanted to work on the tangent plane, however in doing so
# we would lose information about the curvature of the mesh, which may be
# important for very coarse meshes.
return sender_pos_in_rotated_space - receiver_pos_in_rotated_space
def variable_to_stacked(
variable: xarray.Variable,
sizes: Mapping[str, int],
preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
) -> xarray.Variable:
"""Converts an xarray.Variable to preserved_dims + ("channels",).
Any dimensions other than those included in preserved_dims get stacked into a
final "channels" dimension. If any of the preserved_dims are missing then they
are added, with the data broadcast/tiled to match the sizes specified in
`sizes`.
Args:
variable: An xarray.Variable.
sizes: Mapping including sizes for any dimensions which are not present in
`variable` but are needed for the output. This may be needed for example
for a static variable with only ("lat", "lon") dims, or if you want to
encode just the latitude coordinates (a variable with dims ("lat",)).
preserved_dims: dimensions of variable to not be folded in channels.
Returns:
An xarray.Variable with dimensions preserved_dims + ("channels",).
"""
stack_to_channels_dims = [
d for d in variable.dims if d not in preserved_dims]
if stack_to_channels_dims:
variable = variable.stack(channels=stack_to_channels_dims)
dims = {dim: variable.sizes.get(dim) or sizes[dim] for dim in preserved_dims}
dims["channels"] = variable.sizes.get("channels", 1)
return variable.set_dims(dims)
def dataset_to_stacked(
dataset: xarray.Dataset,
sizes: Optional[Mapping[str, int]] = None,
preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
) -> xarray.DataArray:
"""Converts an xarray.Dataset to a single stacked array.
This takes each constituent data_var, converts it into BHWC layout
using `variable_to_stacked`, then concats them all along the channels axis.
Args:
dataset: An xarray.Dataset.
sizes: Mapping including sizes for any dimensions which are not present in
the `dataset` but are needed for the output. See variable_to_stacked.
preserved_dims: dimensions from the dataset that should not be folded in
the predictions channels.
Returns:
An xarray.DataArray with dimensions preserved_dims + ("channels",).
Existing coordinates for preserved_dims axes will be preserved, however
there will be no coordinates for "channels".
"""
data_vars = [
variable_to_stacked(dataset.variables[name], sizes or dataset.sizes,
preserved_dims)
for name in sorted(dataset.data_vars.keys())
]
coords = {
dim: coord
for dim, coord in dataset.coords.items()
if dim in preserved_dims
}
return xarray.DataArray(
data=xarray.Variable.concat(data_vars, dim="channels"), coords=coords)
def stacked_to_dataset(
stacked_array: xarray.Variable,
template_dataset: xarray.Dataset,
preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
) -> xarray.Dataset:
"""The inverse of dataset_to_stacked.
Requires a template dataset to demonstrate the variables/shapes/coordinates
required.
All variables must have preserved_dims dimensions.
Args:
stacked_array: Data in BHWC layout, encoded the same as dataset_to_stacked
would if it was asked to encode `template_dataset`.
template_dataset: A template Dataset (or other mapping of DataArrays)
demonstrating the shape of output required (variables, shapes,
coordinates etc).
preserved_dims: dimensions from the target_template that were not folded in
the predictions channels. The preserved_dims need to be a subset of the
dims of all the variables of template_dataset.
Returns:
An xarray.Dataset (or other mapping of DataArrays) with the same shape and
type as template_dataset.
"""
unstack_from_channels_sizes = {}
var_names = sorted(template_dataset.keys())
for name in var_names:
template_var = template_dataset[name]
if not all(dim in template_var.dims for dim in preserved_dims):
raise ValueError(
f"stacked_to_dataset requires all Variables to have {preserved_dims} "
f"dimensions, but found only {template_var.dims}.")
unstack_from_channels_sizes[name] = {
dim: size for dim, size in template_var.sizes.items()
if dim not in preserved_dims}
channels = {name: np.prod(list(unstack_sizes.values()), dtype=np.int64)
for name, unstack_sizes in unstack_from_channels_sizes.items()}
total_expected_channels = sum(channels.values())
found_channels = stacked_array.sizes["channels"]
if total_expected_channels != found_channels:
raise ValueError(
f"Expected {total_expected_channels} channels but found "
f"{found_channels}, when trying to convert a stacked array of shape "
f"{stacked_array.sizes} to a dataset of shape {template_dataset}.")
data_vars = {}
index = 0
for name in var_names:
template_var = template_dataset[name]
var = stacked_array.isel({"channels": slice(index, index + channels[name])})
index += channels[name]
var = var.unstack({"channels": unstack_from_channels_sizes[name]})
var = var.transpose(*template_var.dims)
data_vars[name] = xarray.DataArray(
data=var,
coords=template_var.coords,
# This might not always be the same as the name it's keyed under; it
# will refer to the original variable name, whereas the key might be
# some alias e.g. temperature_850 under which it should be logged:
name=template_var.name,
)
return type(template_dataset)(data_vars) # pytype:disable=not-callable,wrong-arg-count
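# A round-trip sketch (hypothetical helper with made-up toy dimensions,
# illustration only): `dataset_to_stacked` followed by `stacked_to_dataset`
# with the original dataset as template recovers the original data.
def _example_stacked_roundtrip() -> xarray.Dataset:
  dataset = xarray.Dataset({
      "temperature": (("batch", "lat", "lon", "level"),
                      np.random.randn(1, 4, 8, 3)),
      "surface_pressure": (("batch", "lat", "lon"),
                           np.random.randn(1, 4, 8)),
  })
  stacked = dataset_to_stacked(dataset)
  # 3 "temperature" levels + 1 "surface_pressure" channel = 4 channels.
  assert stacked.sizes["channels"] == 4
  recovered = stacked_to_dataset(stacked.variable, dataset)
  xarray.testing.assert_allclose(recovered, dataset)
  return recovered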
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrappers for Predictors which allow them to work with normalized data.
The Predictor which is wrapped sees normalized inputs and targets, and makes
normalized predictions. The wrapper handles translating the predictions back
to the original domain.
"""
import logging
from typing import Optional, Tuple
from graphcast import predictor_base
from graphcast import xarray_tree
import xarray
def normalize(values: xarray.Dataset,
scales: xarray.Dataset,
locations: Optional[xarray.Dataset],
) -> xarray.Dataset:
"""Normalize variables using the given scales and (optionally) locations."""
def normalize_array(array):
if array.name is None:
raise ValueError(
"Can't look up normalization constants because array has no name.")
if locations is not None:
if array.name in locations:
array = array - locations[array.name].astype(array.dtype)
else:
logging.warning('No normalization location found for %s', array.name)
if array.name in scales:
array = array / scales[array.name].astype(array.dtype)
else:
logging.warning('No normalization scale found for %s', array.name)
return array
return xarray_tree.map_structure(normalize_array, values)
def unnormalize(values: xarray.Dataset,
scales: xarray.Dataset,
locations: Optional[xarray.Dataset],
) -> xarray.Dataset:
"""Unnormalize variables using the given scales and (optionally) locations."""
def unnormalize_array(array):
if array.name is None:
raise ValueError(
"Can't look up normalization constants because array has no name.")
if array.name in scales:
array = array * scales[array.name].astype(array.dtype)
else:
logging.warning('No normalization scale found for %s', array.name)
if locations is not None:
if array.name in locations:
array = array + locations[array.name].astype(array.dtype)
else:
logging.warning('No normalization location found for %s', array.name)
return array
return xarray_tree.map_structure(unnormalize_array, values)
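# A minimal sketch (hypothetical helper with toy numbers, illustration only):
# `normalize` and `unnormalize` are inverses of each other for variables with
# known scales and locations.
def _example_normalize_roundtrip() -> xarray.Dataset:
  values = xarray.Dataset({"2m_temperature": ("batch", [280., 290., 300.])})
  scales = xarray.Dataset({"2m_temperature": 15.0})
  locations = xarray.Dataset({"2m_temperature": 285.0})
  normalized = normalize(values, scales, locations)
  recovered = unnormalize(normalized, scales, locations)
  xarray.testing.assert_allclose(recovered, values)
  return normalized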
class InputsAndResiduals(predictor_base.Predictor):
"""Wraps with a residual connection, normalizing inputs and target residuals.
The inner predictor is given inputs that are normalized using `locations`
and `scales` to roughly zero-mean unit variance.
For target variables that are present in the inputs, the inner predictor is
trained to predict residuals (target - last_frame_of_input) that have been
normalized using `residual_scales` (and optionally `residual_locations`) to
roughly unit variance / zero mean.
This replaces `residual.Predictor` in the case where you want normalization
that's based on the scales of the residuals.
Since we return the underlying predictor's loss on the normalized residuals,
if the underlying predictor is a sum of per-variable losses, the normalization
will affect the relative weighting of the per-variable loss terms (hopefully
in a good way).
For target variables *not* present in the inputs, the inner predictor is
trained to predict targets directly, that have been normalized in the same
way as the inputs.
The transforms applied to the targets (the residual connection and the
normalization) are applied in reverse to the predictions before returning
them.
"""
def __init__(
self,
predictor: predictor_base.Predictor,
stddev_by_level: xarray.Dataset,
mean_by_level: xarray.Dataset,
diffs_stddev_by_level: xarray.Dataset):
self._predictor = predictor
self._scales = stddev_by_level
self._locations = mean_by_level
self._residual_scales = diffs_stddev_by_level
self._residual_locations = None
def _unnormalize_prediction_and_add_input(self, inputs, norm_prediction):
if norm_prediction.sizes.get('time') != 1:
raise ValueError(
'normalization.InputsAndResiduals only supports predicting a '
'single timestep.')
if norm_prediction.name in inputs:
# Residuals are assumed to be predicted as normalized (unit variance),
# but the scale and location they need mapping to is that of the residuals
# not of the values themselves.
prediction = unnormalize(
norm_prediction, self._residual_scales, self._residual_locations)
# A prediction for which we have a corresponding input -- we are
# predicting the residual:
last_input = inputs[norm_prediction.name].isel(time=-1)
prediction = prediction + last_input
return prediction
else:
# A predicted variable which is not an input variable. We are predicting
# it directly, so unnormalize it directly to the target scale/location:
return unnormalize(norm_prediction, self._scales, self._locations)
def _subtract_input_and_normalize_target(self, inputs, target):
if target.sizes.get('time') != 1:
raise ValueError(
'normalization.InputsAndResiduals only supports wrapping predictors '
'that predict a single timestep.')
if target.name in inputs:
target_residual = target
last_input = inputs[target.name].isel(time=-1)
target_residual = target_residual - last_input
return normalize(
target_residual, self._residual_scales, self._residual_locations)
else:
return normalize(target, self._scales, self._locations)
def __call__(self,
inputs: xarray.Dataset,
targets_template: xarray.Dataset,
forcings: xarray.Dataset,
**kwargs
) -> xarray.Dataset:
norm_inputs = normalize(inputs, self._scales, self._locations)
norm_forcings = normalize(forcings, self._scales, self._locations)
norm_predictions = self._predictor(
norm_inputs, targets_template, forcings=norm_forcings, **kwargs)
return xarray_tree.map_structure(
lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
norm_predictions)
def loss(self,
inputs: xarray.Dataset,
targets: xarray.Dataset,
forcings: xarray.Dataset,
**kwargs,
) -> predictor_base.LossAndDiagnostics:
"""Returns the loss computed on normalized inputs and targets."""
norm_inputs = normalize(inputs, self._scales, self._locations)
norm_forcings = normalize(forcings, self._scales, self._locations)
norm_target_residuals = xarray_tree.map_structure(
lambda t: self._subtract_input_and_normalize_target(inputs, t),
targets)
return self._predictor.loss(
norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray
self,
inputs: xarray.Dataset,
targets: xarray.Dataset,
forcings: xarray.Dataset,
**kwargs,
) -> Tuple[predictor_base.LossAndDiagnostics,
xarray.Dataset]:
"""The loss computed on normalized data, with unnormalized predictions."""
norm_inputs = normalize(inputs, self._scales, self._locations)
norm_forcings = normalize(forcings, self._scales, self._locations)
norm_target_residuals = xarray_tree.map_structure(
lambda t: self._subtract_input_and_normalize_target(inputs, t),
targets)
(loss, scalars), norm_predictions = self._predictor.loss_and_predictions(
norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
predictions = xarray_tree.map_structure(
lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
norm_predictions)
return (loss, scalars), predictions
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base classes for an xarray-based Predictor API."""
import abc
from typing import Tuple
from graphcast import losses
from graphcast import xarray_jax
import jax.numpy as jnp
import xarray
LossAndDiagnostics = losses.LossAndDiagnostics
class Predictor(abc.ABC):
"""A possibly-trainable predictor of weather, exposing an xarray-based API.
Typically wraps an underlying JAX model and handles translating the xarray
Dataset values to and from plain JAX arrays that are convenient for input to
(and output from) the underlying model.
Different subclasses may exist to wrap different kinds of underlying model,
e.g. models taking stacked inputs/outputs, models taking separate 2D and 3D
inputs/outputs, autoregressive models.
You can also implement a specific model directly as a Predictor if you want,
for example if it has quite specific/unique requirements for its input/output
or loss function, or if it's convenient to implement directly using xarray.
"""
@abc.abstractmethod
def __call__(self,
inputs: xarray.Dataset,
targets_template: xarray.Dataset,
forcings: xarray.Dataset,
**optional_kwargs
) -> xarray.Dataset:
"""Makes predictions.
This is only used by the Experiment for inference / evaluation, with
training going via the .loss method. So it should default to making
predictions for evaluation, although you can also support making predictions
for use in the loss via an is_training argument -- see
LossFunctionPredictor which helps with that.
Args:
inputs: An xarray.Dataset of inputs.
targets_template: An xarray.Dataset or other mapping of xarray.DataArrays,
with the same shape as the targets, to demonstrate what kind of
predictions are required. You can use this to determine which variables,
levels and lead times must be predicted.
You are free to raise an error if you don't support predicting what is
requested.
      forcings: An xarray.Dataset of forcing terms. Forcings are variables
        that can be fed to the model, but do not need to be predicted. This is
        often because the variable can be computed analytically (e.g. the TOA
        radiation of the sun is mostly a function of geometry) or is considered
        to be controlled for the experiment (e.g., imposing a scenario of CO2
        emissions into the atmosphere). Unlike `inputs`, the `forcings` can
include information "from the future", that is, information at target
times specified in the `targets_template`.
**optional_kwargs: Implementations may support extra optional kwargs,
provided they set appropriate defaults for them.
Returns:
Predictions, as an xarray.Dataset or other mapping of DataArrays which
is capable of being evaluated against targets with shape given by
targets_template.
For probabilistic predictors which can return multiple samples from a
predictive distribution, these should (by convention) be returned along
an additional 'sample' dimension.
"""
def loss(self,
inputs: xarray.Dataset,
targets: xarray.Dataset,
forcings: xarray.Dataset,
**optional_kwargs,
) -> LossAndDiagnostics:
"""Computes a training loss, for predictors that are trainable.
Why make this the Predictor's responsibility, rather than letting callers
compute their own loss function using predictions obtained from
Predictor.__call__?
Doing it this way gives Predictors more control over their training setup.
For example, some predictors may wish to train using different targets to
the ones they predict at evaluation time -- perhaps different lead times and
variables, perhaps training to predict transformed versions of targets
where the transform needs to be inverted at evaluation time, etc.
It's also necessary for generative models (VAEs, GANs, ...) where the
training loss is more complex and isn't expressible as a parameter-free
function of predictions and targets.
Args:
inputs: An xarray.Dataset.
targets: An xarray.Dataset or other mapping of xarray.DataArrays. See
docs on __call__ for an explanation about the targets.
forcings: xarray.Dataset of forcing terms.
**optional_kwargs: Implementations may support extra optional kwargs,
provided they set appropriate defaults for them.
Returns:
loss: A DataArray with dimensions ('batch',) containing losses for each
element of the batch. These will be averaged to give the final
loss, locally and across replicas.
diagnostics: Mapping of additional quantities to log by name alongside the
        loss. These will typically correspond to terms in the loss. They
should also have dimensions ('batch',) and will be averaged over the
batch before logging.
You need not include the loss itself in this dict; it will be added for
you.
"""
del targets, forcings, optional_kwargs
batch_size = inputs.sizes['batch']
dummy_loss = xarray_jax.DataArray(jnp.zeros(batch_size), dims=('batch',))
return dummy_loss, {} # pytype: disable=bad-return-type
def loss_and_predictions(
self,
inputs: xarray.Dataset,
targets: xarray.Dataset,
forcings: xarray.Dataset,
**optional_kwargs,
) -> Tuple[LossAndDiagnostics, xarray.Dataset]:
"""Like .loss but also returns corresponding predictions.
Implementing this is optional as it's not used directly by the Experiment,
but it is required by autoregressive.Predictor when applying an inner
Predictor autoregressively at training time; we need a loss at each step but
also predictions to feed back in for the next step.
    Note the loss itself may not be directly regressing the predictions towards
    targets; it may be computed in terms of transformed predictions and
    targets (or in some other way). For this reason we can't always cleanly
separate this into step 1: get predictions, step 2: compute loss from them,
hence the need for this combined method.
Args:
inputs:
targets:
forcings:
**optional_kwargs:
As for self.loss.
Returns:
(loss, diagnostics)
As for self.loss
predictions:
The predictions which the loss relates to. These should be of the same
shape as what you would get from
`self.__call__(inputs, targets_template=targets)`, and should be in the
same 'domain' as the inputs (i.e. they shouldn't be transformed
differently to how the predictor expects its inputs).
"""
raise NotImplementedError
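# A minimal sketch (not part of the library) of a concrete Predictor: a
# "persistence" model that repeats the last input timestep for every requested
# target time. It is illustrative only; it assumes inputs and targets share
# variable names and relies on the default (dummy) loss defined above.
#
# class PersistencePredictor(Predictor):
#
#   def __call__(self, inputs, targets_template, forcings, **optional_kwargs):
#     del forcings, optional_kwargs  # Unused in this toy model.
#     last_inputs = inputs.isel(time=-1, drop=True)
#     # Broadcast the last input over the template's time axis.
#     return xarray.Dataset(
#         {name: last_inputs[name].broadcast_like(template)
#          for name, template in targets_template.items()
#          if name in last_inputs})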
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for rolling out models."""
from typing import Iterator
from absl import logging
import chex
import dask.array
from graphcast import xarray_tree
import jax
import numpy as np
import typing_extensions
import xarray
class PredictorFn(typing_extensions.Protocol):
"""Functional version of base.Predictor.__call__ with explicit rng."""
def __call__(
self, rng: chex.PRNGKey, inputs: xarray.Dataset,
targets_template: xarray.Dataset,
forcings: xarray.Dataset,
**optional_kwargs,
) -> xarray.Dataset:
...
def chunked_prediction(
predictor_fn: PredictorFn,
rng: chex.PRNGKey,
inputs: xarray.Dataset,
targets_template: xarray.Dataset,
forcings: xarray.Dataset,
num_steps_per_chunk: int = 1,
verbose: bool = False,
) -> xarray.Dataset:
"""Outputs a long trajectory by iteratively concatenating chunked predictions.
Args:
predictor_fn: Function to use to make predictions for each chunk.
rng: Random key.
inputs: Inputs for the model.
targets_template: Template for the target prediction, requires targets
equispaced in time.
    forcings: Forcings for the model.
num_steps_per_chunk: How many of the steps in `targets_template` to predict
at each call of `predictor_fn`. It must evenly divide the number of
steps in `targets_template`.
verbose: Whether to log the current chunk being predicted.
Returns:
Predictions for the targets template.
"""
chunks_list = []
for prediction_chunk in chunked_prediction_generator(
predictor_fn=predictor_fn,
rng=rng,
inputs=inputs,
targets_template=targets_template,
forcings=forcings,
num_steps_per_chunk=num_steps_per_chunk,
verbose=verbose):
chunks_list.append(jax.device_get(prediction_chunk))
return xarray.concat(chunks_list, dim="time")
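# Example usage sketch (not part of the library): rolling out a jitted,
# functional predictor over a targets template in 4-step chunks.
# `run_forward_jitted`, `my_inputs`, `my_targets_template` and `my_forcings`
# are placeholders for objects built elsewhere.
#
# import jax
# predictions = chunked_prediction(
#     predictor_fn=run_forward_jitted,
#     rng=jax.random.PRNGKey(0),
#     inputs=my_inputs,
#     targets_template=my_targets_template,
#     forcings=my_forcings,
#     num_steps_per_chunk=4,
#     verbose=True)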
def chunked_prediction_generator(
predictor_fn: PredictorFn,
rng: chex.PRNGKey,
inputs: xarray.Dataset,
targets_template: xarray.Dataset,
forcings: xarray.Dataset,
num_steps_per_chunk: int = 1,
verbose: bool = False,
) -> Iterator[xarray.Dataset]:
"""Outputs a long trajectory by yielding chunked predictions.
Args:
predictor_fn: Function to use to make predictions for each chunk.
rng: Random key.
inputs: Inputs for the model.
targets_template: Template for the target prediction, requires targets
equispaced in time.
    forcings: Forcings for the model.
num_steps_per_chunk: How many of the steps in `targets_template` to predict
at each call of `predictor_fn`. It must evenly divide the number of
steps in `targets_template`.
verbose: Whether to log the current chunk being predicted.
Yields:
    The predictions for each chunk of the chunked rollout, such that if all
    predictions were concatenated in time they would match the targets
    template in structure.
"""
# Create copies to avoid mutating inputs.
inputs = xarray.Dataset(inputs)
targets_template = xarray.Dataset(targets_template)
forcings = xarray.Dataset(forcings)
if "datetime" in inputs.coords:
del inputs.coords["datetime"]
if "datetime" in targets_template.coords:
output_datetime = targets_template.coords["datetime"]
del targets_template.coords["datetime"]
else:
output_datetime = None
if "datetime" in forcings.coords:
del forcings.coords["datetime"]
num_target_steps = targets_template.dims["time"]
num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk)
if remainder != 0:
raise ValueError(
f"The number of steps per chunk {num_steps_per_chunk} must "
f"evenly divide the number of target steps {num_target_steps} ")
if len(np.unique(np.diff(targets_template.coords["time"].data))) > 1:
raise ValueError("The targets time coordinates must be evenly spaced")
  # Our template targets will always have a time axis corresponding to the
  # timedeltas of the first chunk.
targets_chunk_time = targets_template.time.isel(
time=slice(0, num_steps_per_chunk))
current_inputs = inputs
for chunk_index in range(num_chunks):
if verbose:
logging.info("Chunk %d/%d", chunk_index, num_chunks)
logging.flush()
# Select targets for the time period that we are predicting for this chunk.
target_offset = num_steps_per_chunk * chunk_index
target_slice = slice(target_offset, target_offset + num_steps_per_chunk)
current_targets_template = targets_template.isel(time=target_slice)
    # Replace the timedelta by the one corresponding to the first chunk, so we
    # don't recompile at every iteration, keeping the actual target times so
    # they can be assigned back onto the predictions below.
actual_target_time = current_targets_template.coords["time"]
current_targets_template = current_targets_template.assign_coords(
time=targets_chunk_time).compute()
current_forcings = forcings.isel(time=target_slice)
current_forcings = current_forcings.assign_coords(time=targets_chunk_time)
current_forcings = current_forcings.compute()
# Make predictions for the chunk.
rng, this_rng = jax.random.split(rng)
predictions = predictor_fn(
rng=this_rng,
inputs=current_inputs,
targets_template=current_targets_template,
forcings=current_forcings)
next_frame = xarray.merge([predictions, current_forcings])
next_inputs = _get_next_inputs(current_inputs, next_frame)
# Shift timedelta coordinates, so we don't recompile at every iteration.
next_inputs = next_inputs.assign_coords(time=current_inputs.coords["time"])
current_inputs = next_inputs
# At this point we can assign the actual targets time coordinates.
predictions = predictions.assign_coords(time=actual_target_time)
if output_datetime is not None:
predictions.coords["datetime"] = output_datetime.isel(
time=target_slice)
yield predictions
del predictions
def _get_next_inputs(
prev_inputs: xarray.Dataset, next_frame: xarray.Dataset,
) -> xarray.Dataset:
"""Computes next inputs, from previous inputs and predictions."""
  # Make sure all inputs with a time axis are either predicted or forced.
non_predicted_or_forced_inputs = list(
set(prev_inputs.keys()) - set(next_frame.keys()))
if "time" in prev_inputs[non_predicted_or_forced_inputs].dims:
raise ValueError(
"Found an input with a time index that is not predicted or forced.")
# Keys we need to copy from predictions to inputs.
next_inputs_keys = list(
set(next_frame.keys()).intersection(set(prev_inputs.keys())))
next_inputs = next_frame[next_inputs_keys]
  # Concatenate the next frame with the previous inputs and crop to keep only
  # the last `num_inputs` steps.
num_inputs = prev_inputs.dims["time"]
return (
xarray.concat(
[prev_inputs, next_inputs], dim="time", data_vars="different")
.tail(time=num_inputs))
def extend_targets_template(
targets_template: xarray.Dataset,
required_num_steps: int) -> xarray.Dataset:
"""Extends `targets_template` to `required_num_steps` with lazy arrays.
It uses lazy dask arrays of zeros, so it does not require instantiating the
array in memory.
Args:
targets_template: Input template to extend.
required_num_steps: Number of steps required in the returned template.
Returns:
`xarray.Dataset` identical in variables and timestep to `targets_template`
full of `dask.array.zeros` such that the time axis has `required_num_steps`.
"""
# Extend the "time" and "datetime" coordinates
time = targets_template.coords["time"]
# Assert the first target time corresponds to the timestep.
timestep = time[0].data
if time.shape[0] > 1:
assert np.all(timestep == time[1:] - time[:-1])
extended_time = (np.arange(required_num_steps) + 1) * timestep
if "datetime" in targets_template.coords:
datetime = targets_template.coords["datetime"]
extended_datetime = (datetime[0].data - timestep) + extended_time
else:
extended_datetime = None
# Replace the values with empty dask arrays extending the time coordinates.
def extend_time(data_array: xarray.DataArray) -> xarray.DataArray:
dims = data_array.dims
shape = list(data_array.shape)
shape[dims.index("time")] = required_num_steps
dask_data = dask.array.zeros(
shape=tuple(shape),
        chunks=-1,  # Will give chunk info directly to `ChunksToZarr`.
dtype=data_array.dtype)
coords = dict(data_array.coords)
coords["time"] = extended_time
if extended_datetime is not None:
coords["datetime"] = ("time", extended_datetime)
return xarray.DataArray(
dims=dims,
data=dask_data,
coords=coords)
return xarray_tree.map_structure(extend_time, targets_template)
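# Example usage sketch (not part of the library): extending a single-step
# template to a 40-step lazy template before running a long rollout.
# `one_step_template` is a placeholder for a template built elsewhere with a
# single entry on its "time" axis.
#
# extended_template = extend_targets_template(
#     one_step_template, required_num_steps=40)
# assert extended_template.sizes["time"] == 40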
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computes TOA incident solar radiation compatible with ERA5.
The Top-Of-the-Atmosphere (TOA) incident solar radiation is available in the
ERA5 dataset as the parameter `toa_incident_solar_radiation` (or `tisr`). This
represents the TOA solar radiation flux integrated over a period of one hour
ending at the timestamp given by the `datetime` coordinate. See
https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation and
https://codes.ecmwf.int/grib/param-db/?id=212.
"""
from collections.abc import Callable, Sequence
import dataclasses
import functools
import chex
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import xarray as xa
# Default value of the `integration_period` argument to be compatible with ERA5.
_DEFAULT_INTEGRATION_PERIOD = pd.Timedelta(hours=1)
# Default value for the `num_integration_bins` argument. This provides a good
# approximation of the solar radiation in ERA5.
_DEFAULT_NUM_INTEGRATION_BINS = 360
# The length of a Julian year in days.
# https://en.wikipedia.org/wiki/Julian_year_(astronomy)
_JULIAN_YEAR_LENGTH_IN_DAYS = 365.25
# Julian Date for the J2000 epoch, a standard reference used in astronomy.
# https://en.wikipedia.org/wiki/Epoch_(astronomy)#Julian_years_and_J2000
_J2000_EPOCH = 2451545.0
# Number of seconds in a day.
_SECONDS_PER_DAY = 60 * 60 * 24
_TimestampLike = str | pd.Timestamp | np.datetime64
_TimedeltaLike = str | pd.Timedelta | np.timedelta64
# Interface for loading Total Solar Irradiance (TSI) data.
# Returns a xa.DataArray containing yearly average TSI values with a `time`
# coordinate in units of years since 0000-1-1. E.g. 2023.5 corresponds to
# the middle of the year 2023.
TsiDataLoader = Callable[[], xa.DataArray]
# Total Solar Irradiance (TSI): Energy input to the top of the Earth's
# atmosphere in W⋅m⁻². TSI varies with time. This is the reference TSI value
# that can be used when more accurate data is not available.
# https://www.ncei.noaa.gov/products/climate-data-records/total-solar-irradiance
# https://github.com/ecmwf-ifs/ecrad/blob/6db82f929fb75028cc20606a04da87c0abe9b642/radiation/radiation_ecckd.F90#L296
_REFERENCE_TSI = 1361.0
def reference_tsi_data() -> xa.DataArray:
"""A TsiDataProvider that returns a single reference TSI value."""
return xa.DataArray(
np.array([_REFERENCE_TSI]),
dims=["time"],
coords={"time": np.array([0.0])},
)
def era5_tsi_data() -> xa.DataArray:
"""A TsiDataProvider that returns ERA5 compatible TSI data."""
# ECMWF provided the data used for ERA5, which was hardcoded in the IFS (cycle
# 41r2). The values were scaled down to agree better with more recent
# observations of the sun.
time = np.arange(1951.5, 2035.5, 1.0)
tsi = 0.9965 * np.array([
# fmt: off
# 1951-1995 (non-repeating sequence)
1365.7765, 1365.7676, 1365.6284, 1365.6564, 1365.7773,
1366.3109, 1366.6681, 1366.6328, 1366.3828, 1366.2767,
1365.9199, 1365.7484, 1365.6963, 1365.6976, 1365.7341,
1365.9178, 1366.1143, 1366.1644, 1366.2476, 1366.2426,
1365.9580, 1366.0525, 1365.7991, 1365.7271, 1365.5345,
1365.6453, 1365.8331, 1366.2747, 1366.6348, 1366.6482,
1366.6951, 1366.2859, 1366.1992, 1365.8103, 1365.6416,
1365.6379, 1365.7899, 1366.0826, 1366.6479, 1366.5533,
1366.4457, 1366.3021, 1366.0286, 1365.7971, 1365.6996,
# 1996-2008 (13 year cycle, repeated below)
1365.6121, 1365.7399, 1366.1021, 1366.3851, 1366.6836,
1366.6022, 1366.6807, 1366.2300, 1366.0480, 1365.8545,
1365.8107, 1365.7240, 1365.6918,
# 2009-2021
1365.6121, 1365.7399, 1366.1021, 1366.3851, 1366.6836,
1366.6022, 1366.6807, 1366.2300, 1366.0480, 1365.8545,
1365.8107, 1365.7240, 1365.6918,
# 2022-2034
1365.6121, 1365.7399, 1366.1021, 1366.3851, 1366.6836,
1366.6022, 1366.6807, 1366.2300, 1366.0480, 1365.8545,
1365.8107, 1365.7240, 1365.6918,
# fmt: on
])
return xa.DataArray(tsi, dims=["time"], coords={"time": time})
# HRES compatible TSI data is from IFS cycle 47r1. The dataset can be obtained
# from the ECRAD package: https://confluence.ecmwf.int/display/ECRAD.
# The example code below can load this dataset from a local file.
# def hres_tsi_data() -> xa.DataArray:
# with open("total_solar_irradiance_CMIP6_47r1.nc", "rb") as f:
# with xa.load_dataset(f, decode_times=False) as ds:
# return ds["tsi"]
_DEFAULT_TSI_DATA_LOADER: TsiDataLoader = era5_tsi_data
def get_tsi(
timestamps: Sequence[_TimestampLike], tsi_data: xa.DataArray
) -> chex.Array:
"""Returns TSI values for the given timestamps.
TSI values are interpolated from the provided yearly TSI data.
Args:
timestamps: Timestamps for which to compute TSI values.
tsi_data: A DataArray with a single dimension `time` that has coordinates in
units of years since 0000-1-1. E.g. 2023.5 corresponds to the middle of
the year 2023.
Returns:
An Array containing interpolated TSI data.
"""
timestamps = pd.DatetimeIndex(timestamps)
timestamps_date = pd.DatetimeIndex(timestamps.date)
day_fraction = (timestamps - timestamps_date) / pd.Timedelta(days=1)
year_length = 365 + timestamps.is_leap_year
year_fraction = (timestamps.dayofyear - 1 + day_fraction) / year_length
fractional_year = timestamps.year + year_fraction
return np.interp(fractional_year, tsi_data.coords["time"].data, tsi_data.data)
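# A small worked example (illustrative only): 2020 is a leap year with 366
# days, so 2020-07-02T00:00 is day-of-year 184 and maps to the fractional year
# 2020 + 183/366 = 2020.5, i.e. exactly the middle of the year, which is then
# looked up in the yearly TSI data.
#
# import numpy as np
# tsi_data = xa.DataArray(
#     np.array([1360.0, 1362.0]), dims=["time"],
#     coords={"time": np.array([2020.5, 2021.5])})
# tsi = get_tsi([np.datetime64("2020-07-02T00:00:00")], tsi_data)
# np.testing.assert_allclose(tsi, [1360.0])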
@dataclasses.dataclass(frozen=True)
class _OrbitalParameters:
"""Parameters characterising Earth's position relative to the Sun.
The parameters characterize the position of the Earth in its orbit around the
Sun for specific points in time. Each attribute is an N-dimensional array
to represent orbital parameters for multiple points in time.
Attributes:
theta: The number of Julian years since the Julian epoch J2000.0.
rotational_phase: The phase of the Earth's rotation along its axis as a
ratio with 0 representing the phase at Julian epoch J2000.0 at exactly
12:00 Terrestrial Time (TT). Multiplying this value by `2*pi` yields the
phase in radians.
sin_declination: Sine of the declination of the Sun as seen from the Earth.
cos_declination: Cosine of the declination of the Sun as seen from the
Earth.
eq_of_time_seconds: The value of the equation of time, in seconds.
solar_distance_au: Earth-Sun distance in astronomical units.
"""
theta: chex.Array
rotational_phase: chex.Array
sin_declination: chex.Array
cos_declination: chex.Array
eq_of_time_seconds: chex.Array
solar_distance_au: chex.Array
def _get_j2000_days(timestamp: pd.Timestamp) -> float:
"""Returns the number of days since the J2000 epoch.
Args:
timestamp: A timestamp for which to compute the J2000 days.
Returns:
The J2000 days corresponding to the input timestamp.
"""
return timestamp.to_julian_date() - _J2000_EPOCH
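# A quick sanity check (illustrative only): the J2000 epoch corresponds to
# 2000-01-01 12:00, so that timestamp should map to (approximately) zero days.
# This ignores the small difference between UTC and Terrestrial Time.
#
# import pandas as pd
# assert abs(_get_j2000_days(pd.Timestamp("2000-01-01T12:00:00"))) < 1e-3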
def _get_orbital_parameters(j2000_days: chex.Array) -> _OrbitalParameters:
"""Computes the orbital parameters for the given J2000 days.
Args:
j2000_days: Timestamps represented as the number of days since the J2000
epoch.
Returns:
Orbital parameters for the given timestamps. Each attribute of the return
value is an array containing the same dimensions as the input.
"""
# Orbital parameters are computed based on the formulas in this code, which
# were determined empirically to produce radiation values similar to ERA5:
# https://github.com/ECCC-ASTD-MRD/gem/blob/1d711f7b89971cd7b1e10afc7508d1135b51397d/src/rpnphy/src/base/sucst.F90
# https://github.com/ECCC-ASTD-MRD/gem/blob/1d711f7b89971cd7b1e10afc7508d1135b51397d/src/rpnphy/src/base/fctast.cdk
# https://github.com/ECCC-ASTD-MRD/gem/blob/1d711f7b89971cd7b1e10afc7508d1135b51397d/src/rpnphy/src/base/fcttim.cdk
# There are many variations to these formulas, but since the goal is to match
# the values in ERA5, the formulas were implemented as is. Comments reference
# the notation used in those sources. Here are some additional references
# related to the quantities being computed here:
# https://aa.usno.navy.mil/faq/sun_approx
# https://en.wikipedia.org/wiki/Position_of_the_Sun
# https://en.wikipedia.org/wiki/Equation_of_time
# Number of Julian years since the J2000 epoch (including fractional years).
theta = j2000_days / _JULIAN_YEAR_LENGTH_IN_DAYS
# The phase of the Earth's rotation along its axis as a ratio. 0 represents
# Julian epoch J2000.0 at exactly 12:00 Terrestrial Time (TT).
rotational_phase = j2000_days % 1.0
# REL(PTETA).
rel = 1.7535 + 6.283076 * theta
# REM(PTETA).
rem = 6.240041 + 6.283020 * theta
# RLLS(PTETA).
rlls = 4.8951 + 6.283076 * theta
# Variables used in the three polynomials below.
one = jnp.ones_like(theta)
sin_rel = jnp.sin(rel)
cos_rel = jnp.cos(rel)
sin_two_rel = jnp.sin(2.0 * rel)
cos_two_rel = jnp.cos(2.0 * rel)
sin_two_rlls = jnp.sin(2.0 * rlls)
cos_two_rlls = jnp.cos(2.0 * rlls)
sin_four_rlls = jnp.sin(4.0 * rlls)
sin_rem = jnp.sin(rem)
sin_two_rem = jnp.sin(2.0 * rem)
# Ecliptic longitude of the Sun - RLLLS(PTETA).
rllls = jnp.dot(
jnp.stack(
[one, theta, sin_rel, cos_rel, sin_two_rel, cos_two_rel], axis=-1
),
jnp.array([4.8952, 6.283320, -0.0075, -0.0326, -0.0003, 0.0002]),
)
# Angle in radians between the Earth's rotational axis and its orbital axis.
# Equivalent to 23.4393°.
repsm = 0.409093
# Declination of the Sun - RDS(teta).
sin_declination = jnp.sin(repsm) * jnp.sin(rllls)
cos_declination = jnp.sqrt(1.0 - sin_declination**2)
# Equation of time in seconds - RET(PTETA).
eq_of_time_seconds = jnp.dot(
jnp.stack(
[
sin_two_rlls,
sin_rem,
sin_rem * cos_two_rlls,
sin_four_rlls,
sin_two_rem,
],
axis=-1,
),
jnp.array([591.8, -459.4, 39.5, -12.7, -4.8]),
)
# Earth-Sun distance in astronomical units - RRS(PTETA).
solar_distance_au = jnp.dot(
jnp.stack([one, sin_rel, cos_rel], axis=-1),
jnp.array([1.0001, -0.0163, 0.0037]),
)
return _OrbitalParameters(
theta=theta,
rotational_phase=rotational_phase,
sin_declination=sin_declination,
cos_declination=cos_declination,
eq_of_time_seconds=eq_of_time_seconds,
solar_distance_au=solar_distance_au,
)
def _get_solar_sin_altitude(
op: _OrbitalParameters,
sin_latitude: chex.Array,
cos_latitude: chex.Array,
longitude: chex.Array,
) -> chex.Array:
"""Returns the sine of the solar altitude angle.
All computations are vectorized. Dimensions of all the inputs should be
broadcastable using standard NumPy rules. For example, if `op` has shape
`(T, 1, 1)`, `latitude` has shape `(1, H, 1)`, and `longitude` has shape
`(1, H, W)`, the return value will have shape `(T, H, W)`.
Args:
op: Orbital parameters characterising Earth's position relative to the Sun.
sin_latitude: Sine of latitude coordinates.
cos_latitude: Cosine of latitude coordinates.
longitude: Longitude coordinates in radians.
Returns:
    Sine of the solar altitude angle for each set of orbital parameters and
    each geographical coordinate. The returned array has the shape resulting
    from broadcasting all the inputs together.
"""
solar_time = op.rotational_phase + op.eq_of_time_seconds / _SECONDS_PER_DAY
# https://en.wikipedia.org/wiki/Hour_angle#Solar_hour_angle
hour_angle = 2.0 * jnp.pi * solar_time + longitude
# https://en.wikipedia.org/wiki/Solar_zenith_angle
sin_altitude = (
cos_latitude * op.cos_declination * jnp.cos(hour_angle)
+ sin_latitude * op.sin_declination
)
return sin_altitude
def _get_radiation_flux(
j2000_days: chex.Array,
sin_latitude: chex.Array,
cos_latitude: chex.Array,
longitude: chex.Array,
tsi: chex.Array,
) -> chex.Array:
"""Computes the instantaneous TOA incident solar radiation flux.
  Computes the instantaneous Top-Of-the-Atmosphere (TOA) incident radiation flux
in W⋅m⁻² for the given timestamps and locations on the surface of the Earth.
See https://en.wikipedia.org/wiki/Solar_irradiance.
All inputs are assumed to be broadcastable together using standard NumPy
rules.
Args:
j2000_days: Timestamps represented as the number of days since the J2000
epoch.
sin_latitude: Sine of latitude coordinates.
cos_latitude: Cosine of latitude coordinates.
longitude: Longitude coordinates in radians.
tsi: Total Solar Irradiance (TSI) in W⋅m⁻². This can be a scalar (default)
to use the same TSI value for all the inputs, or an array to allow TSI to
depend on the timestamps.
Returns:
    The instantaneous TOA incident solar radiation flux in W⋅m⁻² for the given
timestamps and geographical coordinates. The returned array has the shape
resulting from broadcasting all the inputs together.
"""
op = _get_orbital_parameters(j2000_days)
# Attenuation of the solar radiation based on the solar distance.
solar_factor = (1.0 / op.solar_distance_au) ** 2
sin_altitude = _get_solar_sin_altitude(
op, sin_latitude, cos_latitude, longitude
)
return tsi * solar_factor * jnp.maximum(sin_altitude, 0.0)
def _get_integrated_radiation(
j2000_days: chex.Array,
sin_latitude: chex.Array,
cos_latitude: chex.Array,
longitude: chex.Array,
tsi: chex.Array,
integration_period: pd.Timedelta,
num_integration_bins: int,
) -> chex.Array:
"""Returns the TOA solar radiation flux integrated over a time period.
Integrates the instantaneous TOA solar radiation flux over a time period.
The input timestamps represent the end times of each integration period.
When the integration period is one hour this approximates the
`toa_incident_solar_radiation` (or `tisr`) parameter from the ERA5 dataset.
See https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation and
https://codes.ecmwf.int/grib/param-db/?id=212.
All inputs are assumed to be broadcastable together using standard NumPy
rules. To approximate the integral, the instantaneous radiation is computed
at `num_integration_bins+1` time steps using `_get_radiation_flux` and
integrated using the trapezoidal rule. A dimension is appended at the end
of all inputs to compute the instantaneous radiation, which is then integrated
over to compute the final result.
Args:
j2000_days: Timestamps represented as the number of days since the J2000
epoch. These correspond to the end times of each integration period.
sin_latitude: Sine of latitude coordinates.
cos_latitude: Cosine of latitude coordinates.
longitude: Longitude in radians.
tsi: Total Solar Irradiance (TSI) in W⋅m⁻².
integration_period: Integration period.
num_integration_bins: Number of bins to divide the `integration_period` to
approximate the integral using the trapezoidal rule.
Returns:
The TOA solar radiation flux integrated over the requested time period for
    the given timestamps and geographical coordinates. The unit is J⋅m⁻².
"""
# Offsets for the integration time steps.
offsets = (
pd.timedelta_range(
start=-integration_period,
end=pd.Timedelta(0),
periods=num_integration_bins + 1,
)
/ pd.Timedelta(days=1)
).to_numpy()
# Integration happens over the time dimension. Compute the instantaneous
# radiation flux for all the integration time steps by appending a dimension
# to all the inputs and adding `offsets` to `j2000_days` (will be broadcast
# over all the other dimensions).
fluxes = _get_radiation_flux(
j2000_days=jnp.expand_dims(j2000_days, axis=-1) + offsets,
sin_latitude=jnp.expand_dims(sin_latitude, axis=-1),
cos_latitude=jnp.expand_dims(cos_latitude, axis=-1),
longitude=jnp.expand_dims(longitude, axis=-1),
tsi=jnp.expand_dims(tsi, axis=-1),
)
# Size of each bin in seconds. The instantaneous solar radiation flux is
# returned in units of W⋅m⁻². Integrating over time expressed in seconds
# yields a result in units of J⋅m⁻².
dx = (integration_period / num_integration_bins) / pd.Timedelta(seconds=1)
return jax.scipy.integrate.trapezoid(fluxes, dx=dx)
_get_integrated_radiation_jitted = jax.jit(
_get_integrated_radiation,
static_argnames=["integration_period", "num_integration_bins"],
)
def get_toa_incident_solar_radiation(
timestamps: Sequence[_TimestampLike],
latitude: chex.Array,
longitude: chex.Array,
tsi_data: xa.DataArray | None = None,
integration_period: _TimedeltaLike = _DEFAULT_INTEGRATION_PERIOD,
num_integration_bins: int = _DEFAULT_NUM_INTEGRATION_BINS,
use_jit: bool = False,
) -> chex.Array:
"""Computes the solar radiation incident at the top of the atmosphere.
The solar radiation is computed for each element in `timestamps` for all the
locations on the grid determined by the `latitude` and `longitude` parameters.
To approximate the `toa_incident_solar_radiation` (or `tisr`) parameter from
the ERA5 dataset, set `integration_period` to one hour (default). See
https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation and
https://codes.ecmwf.int/grib/param-db/?id=212.
Args:
timestamps: Timestamps for which to compute the solar radiation.
latitude: The latitude coordinates in degrees of the grid for which to
compute the solar radiation.
longitude: The longitude coordinates in degrees of the grid for which to
compute the solar radiation.
tsi_data: A DataArray containing yearly TSI data as returned by a
`TsiDataLoader`. The default is to use ERA5 compatible TSI data.
integration_period: Timedelta to use to integrate the radiation, e.g. if
producing radiation for 1989-11-08 21:00:00, and `integration_period` is
"1h", radiation will be integrated from 1989-11-08 20:00:00 to 1989-11-08
21:00:00. The default value ("1h") matches ERA5.
num_integration_bins: Number of equally spaced bins to divide the
`integration_period` in when approximating the integral using the
trapezoidal rule. Performance and peak memory usage are affected by this
      value. The default (360) provides a good approximation, but lower values
      can be used to improve performance and reduce memory usage.
use_jit: Set to True to use the jitted implementation, or False (default) to
use the non-jitted one.
Returns:
    A 3D array with dimensions (time, lat, lon) containing the total
top of atmosphere solar radiation integrated for the `integration_period`
up to each timestamp.
"""
# Add a trailing dimension to latitude to get dimensions (lat, lon).
lat = jnp.radians(latitude).reshape((-1, 1))
lon = jnp.radians(longitude)
sin_lat = jnp.sin(lat)
cos_lat = jnp.cos(lat)
integration_period = pd.Timedelta(integration_period)
if tsi_data is None:
tsi_data = _DEFAULT_TSI_DATA_LOADER()
tsi = get_tsi(timestamps, tsi_data)
fn = (
_get_integrated_radiation_jitted if use_jit else _get_integrated_radiation
)
# Compute integral for each timestamp individually. Although this could be
# done in one step, peak memory usage would be proportional to
# `len(timestamps) * num_integration_bins`. Computing each timestamp
# individually reduces this to `max(len(timestamps), num_integration_bins)`.
# E.g. memory usage for a single timestamp, with a full 0.25° grid and 360
# integration bins is about 1.5 GB (1440 * 721 * 361 * 4 bytes); computing
# forcings for 40 prediction steps would require 60 GB.
results = []
for idx, timestamp in enumerate(timestamps):
results.append(
fn(
j2000_days=jnp.array(_get_j2000_days(pd.Timestamp(timestamp))),
sin_latitude=sin_lat,
cos_latitude=cos_lat,
longitude=lon,
tsi=tsi[idx],
integration_period=integration_period,
num_integration_bins=num_integration_bins,
)
)
return jnp.stack(results, axis=0)
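# Example usage sketch (not part of the library): computing `tisr` for two
# timestamps on a coarse 1-degree grid. The timestamps and grid below are
# illustrative only.
#
# import numpy as np
# lat = np.arange(-90.0, 90.5, 1.0)
# lon = np.arange(0.0, 360.0, 1.0)
# radiation = get_toa_incident_solar_radiation(
#     timestamps=["2023-09-25T06:00:00", "2023-09-25T12:00:00"],
#     latitude=lat,
#     longitude=lon)
# assert radiation.shape == (2, lat.shape[0], lon.shape[0])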
def get_toa_incident_solar_radiation_for_xarray(
data_array_like: xa.DataArray | xa.Dataset,
tsi_data: xa.DataArray | None = None,
integration_period: _TimedeltaLike = _DEFAULT_INTEGRATION_PERIOD,
num_integration_bins: int = _DEFAULT_NUM_INTEGRATION_BINS,
use_jit: bool = False,
) -> xa.DataArray:
"""Computes the solar radiation incident at the top of the atmosphere.
  This method is a wrapper for `get_toa_incident_solar_radiation` that takes
  its coordinates from an xarray object and returns an xarray.DataArray.
Args:
    data_array_like: An xa.Dataset or xa.DataArray from which to take the time
and spatial coordinates for which to compute the solar radiation. It must
contain `lat` and `lon` spatial dimensions with corresponding coordinates.
If a `time` dimension is present, the `datetime` coordinate should be a
vector associated with that dimension containing timestamps for which to
compute the solar radiation. Otherwise, the `datetime` coordinate should
be a scalar representing the timestamp for which to compute the solar
radiation.
tsi_data: A DataArray containing yearly TSI data as returned by a
`TsiDataLoader`. The default is to use ERA5 compatible TSI data.
integration_period: Timedelta to use to integrate the radiation, e.g. if
producing radiation for 1989-11-08 21:00:00, and `integration_period` is
"1h", radiation will be integrated from 1989-11-08 20:00:00 to 1989-11-08
21:00:00. The default value ("1h") matches ERA5.
num_integration_bins: Number of equally spaced bins to divide the
`integration_period` in when approximating the integral using the
trapezoidal rule. Performance and peak memory usage are affected by this
      value. The default (360) provides a good approximation, but lower values
      can be used to improve performance and reduce memory usage.
use_jit: Set to True to use the jitted implementation, or False to use the
non-jitted one.
Returns:
xa.DataArray with dimensions `(time, lat, lon)` if `data_array_like` had
a `time` dimension; or dimensions `(lat, lon)` otherwise. The `datetime`
coordinates and those for the dimensions are copied to the returned array.
The array contains the total top of atmosphere solar radiation integrated
for `integration_period` up to the corresponding `datetime`.
Raises:
ValueError: If there are missing coordinates or dimensions.
"""
missing_dims = set(["lat", "lon"]) - set(data_array_like.dims)
if missing_dims:
raise ValueError(
f"'{missing_dims}' dimensions are missing in `data_array_like`."
)
missing_coords = set(["datetime", "lat", "lon"]) - set(data_array_like.coords)
if missing_coords:
raise ValueError(
f"'{missing_coords}' coordinates are missing in `data_array_like`."
)
if "time" in data_array_like.dims:
timestamps = data_array_like.coords["datetime"].data
else:
timestamps = [data_array_like.coords["datetime"].data.item()]
radiation = get_toa_incident_solar_radiation(
timestamps=timestamps,
latitude=data_array_like.coords["lat"].data,
longitude=data_array_like.coords["lon"].data,
tsi_data=tsi_data,
integration_period=integration_period,
num_integration_bins=num_integration_bins,
use_jit=use_jit,
)
if "time" in data_array_like.dims:
output = xa.DataArray(radiation, dims=("time", "lat", "lon"))
else:
output = xa.DataArray(radiation[0], dims=("lat", "lon"))
# Preserve as many of the original coordinates as possible, so long as the
# dimension or the coordinate still exist in the output array.
for k, coord in data_array_like.coords.items():
if set(coord.dims).issubset(set(output.dims)):
output.coords[k] = coord
return output
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import timeit
from typing import Sequence
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
from graphcast import solar_radiation
import numpy as np
import pandas as pd
import xarray as xa
def _get_grid_lat_lon_coords(
num_lat: int, num_lon: int
) -> tuple[np.ndarray, np.ndarray]:
"""Generates a linear latitude-longitude grid of the given size.
Args:
num_lat: Size of the latitude dimension of the grid.
num_lon: Size of the longitude dimension of the grid.
Returns:
A tuple `(lat, lon)` containing 1D arrays with the latitude and longitude
coordinates in degrees of the generated grid.
"""
lat = np.linspace(-90.0, 90.0, num=num_lat, endpoint=True)
lon = np.linspace(0.0, 360.0, num=num_lon, endpoint=False)
return lat, lon
class SolarRadiationTest(parameterized.TestCase):
def setUp(self):
super().setUp()
np.random.seed(0)
def test_missing_dim_raises_value_error(self):
data = xa.DataArray(
np.random.randn(2, 2),
coords=[np.array([0.1, 0.2]), np.array([0.0, 0.5])],
dims=["lon", "x"],
)
with self.assertRaisesRegex(
ValueError, r".* dimensions are missing in `data_array_like`."
):
solar_radiation.get_toa_incident_solar_radiation_for_xarray(
data, integration_period="1h", num_integration_bins=360
)
def test_missing_coordinate_raises_value_error(self):
data = xa.Dataset(
data_vars={"var1": (["x", "lat", "lon"], np.random.randn(2, 3, 2))},
coords={
"lat": np.array([0.0, 0.1, 0.2]),
"lon": np.array([0.0, 0.5]),
},
)
with self.assertRaisesRegex(
ValueError, r".* coordinates are missing in `data_array_like`."
):
solar_radiation.get_toa_incident_solar_radiation_for_xarray(
data, integration_period="1h", num_integration_bins=360
)
def test_shape_multiple_timestamps(self):
data = xa.Dataset(
data_vars={"var1": (["time", "lat", "lon"], np.random.randn(2, 4, 2))},
coords={
"lat": np.array([0.0, 0.1, 0.2, 0.3]),
"lon": np.array([0.0, 0.5]),
"time": np.array([100, 200], dtype="timedelta64[s]"),
"datetime": xa.Variable(
"time", np.array([10, 20], dtype="datetime64[D]")
),
},
)
actual = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
data, integration_period="1h", num_integration_bins=2
)
self.assertEqual(("time", "lat", "lon"), actual.dims)
self.assertEqual((2, 4, 2), actual.shape)
def test_shape_single_timestamp(self):
data = xa.Dataset(
data_vars={"var1": (["lat", "lon"], np.random.randn(4, 2))},
coords={
"lat": np.array([0.0, 0.1, 0.2, 0.3]),
"lon": np.array([0.0, 0.5]),
"datetime": np.datetime64(10, "D"),
},
)
actual = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
data, integration_period="1h", num_integration_bins=2
)
self.assertEqual(("lat", "lon"), actual.dims)
self.assertEqual((4, 2), actual.shape)
@parameterized.named_parameters(
dict(
testcase_name="one_timestamp_jitted",
periods=1,
repeats=3,
use_jit=True,
),
dict(
testcase_name="one_timestamp_non_jitted",
periods=1,
repeats=3,
use_jit=False,
),
dict(
testcase_name="ten_timestamps_non_jitted",
periods=10,
repeats=1,
use_jit=False,
),
)
def test_full_spatial_resolution(
self, periods: int, repeats: int, use_jit: bool
):
timestamps = pd.date_range(start="2023-09-25", periods=periods, freq="6h")
# Generate a linear grid with 0.25 degrees resolution similar to ERA5.
lat, lon = _get_grid_lat_lon_coords(num_lat=721, num_lon=1440)
def benchmark() -> None:
solar_radiation.get_toa_incident_solar_radiation(
timestamps,
lat,
lon,
integration_period="1h",
num_integration_bins=360,
use_jit=use_jit,
).block_until_ready()
results = timeit.repeat(benchmark, repeat=repeats, number=1)
logging.info(
"Times to compute `tisr` for input of shape `%d, %d, %d` (seconds): %s",
len(timestamps),
len(lat),
len(lon),
np.array2string(np.array(results), precision=1),
)
class GetTsiTest(parameterized.TestCase):
@parameterized.named_parameters(
dict(
testcase_name="reference_tsi_data",
loader=solar_radiation.reference_tsi_data,
expected_tsi=np.array([1361.0]),
),
dict(
testcase_name="era5_tsi_data",
loader=solar_radiation.era5_tsi_data,
expected_tsi=np.array([1360.9440]), # 0.9965 * 1365.7240
),
)
def test_mid_2020_lookup(
self, loader: solar_radiation.TsiDataLoader, expected_tsi: np.ndarray
):
tsi_data = loader()
tsi = solar_radiation.get_tsi(
[np.datetime64("2020-07-02T00:00:00")], tsi_data
)
np.testing.assert_allclose(expected_tsi, tsi)
@parameterized.named_parameters(
dict(
testcase_name="beginning_2020_left_boundary",
timestamps=[np.datetime64("2020-01-01T00:00:00")],
expected_tsi=np.array([1000.0]),
),
dict(
testcase_name="mid_2020_exact",
timestamps=[np.datetime64("2020-07-02T00:00:00")],
expected_tsi=np.array([1000.0]),
),
dict(
testcase_name="beginning_2021_interpolated",
timestamps=[np.datetime64("2021-01-01T00:00:00")],
expected_tsi=np.array([1150.0]),
),
dict(
testcase_name="mid_2021_lookup",
timestamps=[np.datetime64("2021-07-02T12:00:00")],
expected_tsi=np.array([1300.0]),
),
dict(
testcase_name="beginning_2022_interpolated",
timestamps=[np.datetime64("2022-01-01T00:00:00")],
expected_tsi=np.array([1250.0]),
),
dict(
testcase_name="mid_2022_lookup",
timestamps=[np.datetime64("2022-07-02T12:00:00")],
expected_tsi=np.array([1200.0]),
),
dict(
testcase_name="beginning_2023_right_boundary",
timestamps=[np.datetime64("2023-01-01T00:00:00")],
expected_tsi=np.array([1200.0]),
),
)
def test_interpolation(
self, timestamps: Sequence[np.datetime64], expected_tsi: np.ndarray
):
tsi_data = xa.DataArray(
np.array([1000.0, 1300.0, 1200.0]),
dims=["time"],
coords={"time": np.array([2020.5, 2021.5, 2022.5])},
)
tsi = solar_radiation.get_tsi(timestamps, tsi_data)
np.testing.assert_allclose(expected_tsi, tsi)
if __name__ == "__main__":
absltest.main()
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data-structure for storing graphs with typed edges and nodes."""
from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar
ArrayLike = Union[Any] # np.ndarray, jnp.ndarray, tf.tensor
ArrayLikeTree = Union[Any, ArrayLike] # Nest of ArrayLike
_T = TypeVar('_T')
# All tensors have a "flat_batch_axis", which is similar to the leading
# axes of graph_tuples:
# * In the case of nodes this is simply a shared node and flat batch axis, with
# size corresponding to the total number of nodes in the flattened batch.
# * In the case of edges this is simply a shared edge and flat batch axis, with
# size corresponding to the total number of edges in the flattened batch.
# * In the case of globals this is simply the number of graphs in the flattened
# batch.
# All shapes may also have any additional leading shape "batch_shape".
# Options for building batches are:
# * Use a provided "flatten" method that takes a leading `batch_shape` and
# it into the flat_batch_axis (this will be useful when using `tf.Dataset`
# which supports batching into RaggedTensors, with leading batch shape even
# if graphs have different numbers of nodes and edges), so the RaggedBatches
# can then be converted into something without ragged dimensions that jax can
# use.
# * Directly build a "flat batch" using a provided function for batching a list
# of graphs (how it is done in `jraph`).
class NodeSet(NamedTuple):
"""Represents a set of nodes."""
n_node: ArrayLike # [num_flat_graphs]
features: ArrayLikeTree # Prev. `nodes`: [num_flat_nodes] + feature_shape
class EdgesIndices(NamedTuple):
"""Represents indices to nodes adjacent to the edges."""
senders: ArrayLike # [num_flat_edges]
receivers: ArrayLike # [num_flat_edges]
class EdgeSet(NamedTuple):
"""Represents a set of edges."""
n_edge: ArrayLike # [num_flat_graphs]
indices: EdgesIndices
features: ArrayLikeTree # Prev. `edges`: [num_flat_edges] + feature_shape
class Context(NamedTuple):
# `n_graph` always contains ones but it is useful to query the leading shape
  # in case of graphs without any node or edge sets.
n_graph: ArrayLike # [num_flat_graphs]
features: ArrayLikeTree # Prev. `globals`: [num_flat_graphs] + feature_shape
class EdgeSetKey(NamedTuple):
name: str # Name of the EdgeSet.
# Sender node set name and receiver node set name connected by the edge set.
node_sets: Tuple[str, str]
class TypedGraph(NamedTuple):
"""A graph with typed nodes and edges.
A typed graph is made of a context, multiple sets of nodes and multiple
sets of edges connecting those nodes (as indicated by the EdgeSetKey).
"""
context: Context
nodes: Mapping[str, NodeSet]
edges: Mapping[EdgeSetKey, EdgeSet]
def edge_key_by_name(self, name: str) -> EdgeSetKey:
found_key = [k for k in self.edges.keys() if k.name == name]
if len(found_key) != 1:
raise KeyError("invalid edge key '{}'. Available edges: [{}]".format(
name, ', '.join(x.name for x in self.edges.keys())))
return found_key[0]
def edge_by_name(self, name: str) -> EdgeSet:
return self.edges[self.edge_key_by_name(name)]
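# A minimal sketch (not part of the library) of building a TypedGraph with one
# node set and one edge set, using numpy arrays. The set names are illustrative;
# shapes follow the "flat batch" convention described above: one flattened
# graph, three nodes, two edges.
#
# import numpy as np
# graph = TypedGraph(
#     context=Context(n_graph=np.array([1]), features=()),
#     nodes={
#         "mesh_nodes": NodeSet(
#             n_node=np.array([3]),
#             features=np.zeros((3, 4), dtype=np.float32)),
#     },
#     edges={
#         EdgeSetKey(name="mesh_edges", node_sets=("mesh_nodes", "mesh_nodes")):
#             EdgeSet(
#                 n_edge=np.array([2]),
#                 indices=EdgesIndices(
#                     senders=np.array([0, 1]),
#                     receivers=np.array([1, 2])),
#                 features=np.zeros((2, 4), dtype=np.float32)),
#     })
# assert graph.edge_by_name("mesh_edges").n_edge[0] == 2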
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A library of typed Graph Neural Networks."""
from typing import Callable, Mapping, Optional, Union
from graphcast import typed_graph
import jax.numpy as jnp
import jax.tree_util as tree
import jraph
# All features will be an ArrayTree.
NodeFeatures = EdgeFeatures = SenderFeatures = ReceiverFeatures = Globals = (
jraph.ArrayTree)
# Signature:
# (node features, outgoing edge features, incoming edge features,
# globals) -> updated node features
GNUpdateNodeFn = Callable[
[NodeFeatures, Mapping[str, SenderFeatures], Mapping[str, ReceiverFeatures],
Globals],
NodeFeatures]
GNUpdateGlobalFn = Callable[
[Mapping[str, NodeFeatures], Mapping[str, EdgeFeatures], Globals],
Globals]
def GraphNetwork( # pylint: disable=invalid-name
update_edge_fn: Mapping[str, jraph.GNUpdateEdgeFn],
update_node_fn: Mapping[str, GNUpdateNodeFn],
update_global_fn: Optional[GNUpdateGlobalFn] = None,
aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph
.segment_sum,
aggregate_nodes_for_globals_fn: jraph.AggregateNodesToGlobalsFn = jraph
.segment_sum,
aggregate_edges_for_globals_fn: jraph.AggregateEdgesToGlobalsFn = jraph
.segment_sum,
):
"""Returns a method that applies a configured GraphNetwork.
This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
extended to Typed Graphs with multiple edge sets and node sets and extended to
allow aggregating not only edges received by the nodes, but also edges sent by
the nodes.
Example usage::
gn = GraphNetwork(update_edge_function,
update_node_function, **kwargs)
# Conduct multiple rounds of message passing with the same parameters:
for _ in range(num_message_passing_steps):
graph = gn(graph)
Args:
update_edge_fn: mapping of functions used to update a subset of the edge
types, indexed by edge type name.
update_node_fn: mapping of functions used to update a subset of the node
types, indexed by node type name.
update_global_fn: function used to update the globals or None to deactivate
globals updates.
aggregate_edges_for_nodes_fn: function used to aggregate messages to each
node.
aggregate_nodes_for_globals_fn: function used to aggregate the nodes for the
globals.
aggregate_edges_for_globals_fn: function used to aggregate the edges for the
globals.
Returns:
A method that applies the configured GraphNetwork.
"""
def _apply_graph_net(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
"""Applies a configured GraphNetwork to a graph.
This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
extended to Typed Graphs with multiple edge sets and node sets and extended
to allow aggregating not only edges received by the nodes, but also edges
sent by the nodes.
Args:
graph: a `TypedGraph` containing the graph.
Returns:
Updated `TypedGraph`.
"""
updated_graph = graph
# Edge update.
updated_edges = dict(updated_graph.edges)
for edge_set_name, edge_fn in update_edge_fn.items():
edge_set_key = graph.edge_key_by_name(edge_set_name)
updated_edges[edge_set_key] = _edge_update(
updated_graph, edge_fn, edge_set_key)
updated_graph = updated_graph._replace(edges=updated_edges)
# Node update.
updated_nodes = dict(updated_graph.nodes)
for node_set_key, node_fn in update_node_fn.items():
updated_nodes[node_set_key] = _node_update(
updated_graph, node_fn, node_set_key, aggregate_edges_for_nodes_fn)
updated_graph = updated_graph._replace(nodes=updated_nodes)
# Global update.
if update_global_fn:
updated_context = _global_update(
updated_graph, update_global_fn,
aggregate_edges_for_globals_fn,
aggregate_nodes_for_globals_fn)
updated_graph = updated_graph._replace(context=updated_context)
return updated_graph
return _apply_graph_net
def _edge_update(graph, edge_fn, edge_set_key): # pylint: disable=invalid-name
"""Updates an edge set of a given key."""
sender_nodes = graph.nodes[edge_set_key.node_sets[0]]
receiver_nodes = graph.nodes[edge_set_key.node_sets[1]]
edge_set = graph.edges[edge_set_key]
senders = edge_set.indices.senders # pytype: disable=attribute-error
receivers = edge_set.indices.receivers # pytype: disable=attribute-error
sent_attributes = tree.tree_map(
lambda n: n[senders], sender_nodes.features)
received_attributes = tree.tree_map(
lambda n: n[receivers], receiver_nodes.features)
n_edge = edge_set.n_edge
sum_n_edge = senders.shape[0]
global_features = tree.tree_map(
lambda g: jnp.repeat(g, n_edge, axis=0, total_repeat_length=sum_n_edge),
graph.context.features)
new_features = edge_fn(
edge_set.features, sent_attributes, received_attributes,
global_features)
return edge_set._replace(features=new_features)
def _node_update(graph, node_fn, node_set_key, aggregation_fn): # pylint: disable=invalid-name
"""Updates an edge set of a given key."""
node_set = graph.nodes[node_set_key]
sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0]
sent_features = {}
for edge_set_key, edge_set in graph.edges.items():
sender_node_set_key = edge_set_key.node_sets[0]
if sender_node_set_key == node_set_key:
assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
senders = edge_set.indices.senders
sent_features[edge_set_key.name] = tree.tree_map(
lambda e: aggregation_fn(e, senders, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop
received_features = {}
for edge_set_key, edge_set in graph.edges.items():
receiver_node_set_key = edge_set_key.node_sets[1]
if receiver_node_set_key == node_set_key:
assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
receivers = edge_set.indices.receivers
received_features[edge_set_key.name] = tree.tree_map(
lambda e: aggregation_fn(e, receivers, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop
n_node = node_set.n_node
global_features = tree.tree_map(
lambda g: jnp.repeat(g, n_node, axis=0, total_repeat_length=sum_n_node),
graph.context.features)
new_features = node_fn(
node_set.features, sent_features, received_features, global_features)
return node_set._replace(features=new_features)
def _global_update(graph, global_fn, edge_aggregation_fn, node_aggregation_fn): # pylint: disable=invalid-name
"""Updates an edge set of a given key."""
n_graph = graph.context.n_graph.shape[0]
graph_idx = jnp.arange(n_graph)
edge_features = {}
for edge_set_key, edge_set in graph.edges.items():
assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
sum_n_edge = edge_set.indices.senders.shape[0]
edge_gr_idx = jnp.repeat(
graph_idx, edge_set.n_edge, axis=0, total_repeat_length=sum_n_edge)
edge_features[edge_set_key.name] = tree.tree_map(
lambda e: edge_aggregation_fn(e, edge_gr_idx, n_graph), # pylint: disable=cell-var-from-loop
edge_set.features)
node_features = {}
for node_set_key, node_set in graph.nodes.items():
sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0]
node_gr_idx = jnp.repeat(
graph_idx, node_set.n_node, axis=0, total_repeat_length=sum_n_node)
node_features[node_set_key] = tree.tree_map(
lambda n: node_aggregation_fn(n, node_gr_idx, n_graph), # pylint: disable=cell-var-from-loop
node_set.features)
new_features = global_fn(node_features, edge_features, graph.context.features)
return graph.context._replace(features=new_features)
InteractionUpdateNodeFn = Callable[
[jraph.NodeFeatures,
Mapping[str, SenderFeatures],
Mapping[str, ReceiverFeatures]],
jraph.NodeFeatures]
InteractionUpdateNodeFnNoSentEdges = Callable[
[jraph.NodeFeatures,
Mapping[str, ReceiverFeatures]],
jraph.NodeFeatures]
def InteractionNetwork( # pylint: disable=invalid-name
update_edge_fn: Mapping[str, jraph.InteractionUpdateEdgeFn],
update_node_fn: Mapping[str, Union[InteractionUpdateNodeFn,
InteractionUpdateNodeFnNoSentEdges]],
aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph
.segment_sum,
include_sent_messages_in_node_update: bool = False):
"""Returns a method that applies a configured InteractionNetwork.
An interaction network computes interactions on the edges based on the
previous edges features, and on the features of the nodes sending into those
edges. It then updates the nodes based on the incoming updated edges.
See https://arxiv.org/abs/1612.00222 for more details.
This implementation extends the behavior to `TypedGraphs` adding an option
to include edge features for which a node is a sender in the arguments to
the node update function.
Args:
update_edge_fn: mapping of functions used to update a subset of the edge
types, indexed by edge type name.
update_node_fn: mapping of functions used to update a subset of the node
types, indexed by node type name.
aggregate_edges_for_nodes_fn: function used to aggregate messages to each
node.
include_sent_messages_in_node_update: pass edge features for which a node is
a sender to the node update function.
"""
# An InteractionNetwork is a GraphNetwork without globals features,
# so we implement the InteractionNetwork as a configured GraphNetwork.
# An InteractionNetwork edge function does not have global feature inputs,
# so we filter the passed global argument in the GraphNetwork.
wrapped_update_edge_fn = tree.tree_map(
lambda fn: lambda e, s, r, g: fn(e, s, r), update_edge_fn)
# Similarly, we wrap the update_node_fn to ensure only the expected
# arguments are passed to the Interaction net.
if include_sent_messages_in_node_update:
wrapped_update_node_fn = tree.tree_map(
lambda fn: lambda n, s, r, g: fn(n, s, r), update_node_fn)
else:
wrapped_update_node_fn = tree.tree_map(
lambda fn: lambda n, s, r, g: fn(n, r), update_node_fn)
return GraphNetwork(
update_edge_fn=wrapped_update_edge_fn,
update_node_fn=wrapped_update_node_fn,
aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn)
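# Illustrative sketch (names are hypothetical, not part of the library):
# configuring an InteractionNetwork for a TypedGraph with a single node set
# 'mesh' and a single edge set 'mesh2mesh'. A real model would use learned
# modules (e.g. haiku MLPs) in place of these placeholder lambdas.
#
#   net = InteractionNetwork(
#       update_edge_fn={
#           'mesh2mesh': lambda edges, senders, receivers: jnp.concatenate(
#               [edges, senders, receivers], axis=-1)},
#       update_node_fn={
#           'mesh': lambda nodes, received: jnp.concatenate(
#               [nodes] + list(received.values()), axis=-1)})
#   updated_graph = net(graph)  # `graph` is a typed_graph.TypedGraph.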
def GraphMapFeatures( # pylint: disable=invalid-name
embed_edge_fn: Optional[Mapping[str, jraph.EmbedEdgeFn]] = None,
embed_node_fn: Optional[Mapping[str, jraph.EmbedNodeFn]] = None,
embed_global_fn: Optional[jraph.EmbedGlobalFn] = None):
"""Returns function which embeds the components of a graph independently.
Args:
embed_edge_fn: mapping of functions used to embed each edge type,
indexed by edge type name.
embed_node_fn: mapping of functions used to embed each node type,
indexed by node type name.
embed_global_fn: function used to embed the globals.
"""
def _embed(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
updated_edges = dict(graph.edges)
if embed_edge_fn:
for edge_set_name, embed_fn in embed_edge_fn.items():
edge_set_key = graph.edge_key_by_name(edge_set_name)
edge_set = graph.edges[edge_set_key]
updated_edges[edge_set_key] = edge_set._replace(
features=embed_fn(edge_set.features))
updated_nodes = dict(graph.nodes)
if embed_node_fn:
for node_set_key, embed_fn in embed_node_fn.items():
node_set = graph.nodes[node_set_key]
updated_nodes[node_set_key] = node_set._replace(
features=embed_fn(node_set.features))
updated_context = graph.context
if embed_global_fn:
updated_context = updated_context._replace(
features=embed_global_fn(updated_context.features))
return graph._replace(edges=updated_edges, nodes=updated_nodes,
context=updated_context)
return _embed
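# Illustrative sketch (hypothetical names): embedding each edge and node type
# of a TypedGraph independently. In practice the embed functions would be
# learned embedding networks rather than these placeholder lambdas.
#
#   embedder = GraphMapFeatures(
#       embed_edge_fn={'mesh2mesh': lambda e: e * 2.0},
#       embed_node_fn={'mesh': lambda n: n + 1.0})
#   embedded_graph = embedder(graph)  # `graph` is a typed_graph.TypedGraph.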
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers to use xarray.{Variable,DataArray,Dataset} with JAX.
Allows them to be based on JAX arrays without converting to numpy arrays under
the hood, so you can start with a JAX array, do some computation with it in
xarray-land, get a JAX array out the other end and (for example) jax.jit
through the whole thing. You can even jax.jit a function which accepts and
returns xarray.Dataset, DataArray and Variable.
## Creating xarray datatypes from jax arrays, and vice-versa.
You can use the xarray_jax.{Variable,DataArray,Dataset} constructors, which have
the same API as the standard xarray constructors but will accept JAX arrays
without converting them to numpy.
It does this by wrapping the JAX array in a wrapper before passing it to
xarray; you can also do this manually by calling xarray_jax.wrap on your JAX
arrays before passing them to the standard xarray constructors.
To get non-wrapped JAX arrays out the other end, you can use e.g.:
xarray_jax.jax_vars(dataset)
xarray_jax.jax_data(dataset.some_var)
which will complain if the data isn't actually a JAX array. Use this if you need
to make sure the computation has gone via JAX, e.g. if it's the output of code
that you want to JIT or compute gradients through. If this is not the case and
you want to support passing plain numpy arrays through as well as potentially
JAX arrays, you can use:
xarray_jax.unwrap_vars(dataset)
xarray_jax.unwrap_data(dataset.some_var)
which will unwrap the data if it is a wrapped JAX array, but otherwise pass
it through to you without complaint.
The wrapped JAX arrays aim to support all the core operations from the numpy
array API that xarray expects, however there may still be some gaps; if you run
into any problems around this, you may need to add a few more proxy methods onto
the wrapper class below.
In future once JAX and xarray support the new Python array API standard
(https://data-apis.org/array-api/latest/index.html), we hope to avoid the need
for wrapping the JAX arrays like this.
## jax.jit and pmap of functions taking and returning xarray datatypes
We register xarray datatypes with jax.tree_util, which allows them to be treated
as generic containers of jax arrays by various parts of jax including jax.jit.
This allows for, e.g.:
@jax.jit
def foo(input: xarray.Dataset) -> xarray.Dataset:
...
It will not work out-of-the-box with shape-modifying transformations like
jax.pmap, or e.g. a jax.tree_util.tree_map with some transform that alters array
shapes or dimension order. That's because we won't know what dimension names
and/or coordinates to use when unflattening, if the results have a different
shape to the data that was originally flattened.
You can work around this using xarray_jax.dims_change_on_unflatten, however,
and in the case of jax.pmap we provide a wrapper xarray_jax.pmap which allows
it to be used with functions taking and returning xarrays.
## Treatment of coordinates
We don't support passing jax arrays as coordinates when constructing a
DataArray/Dataset. This is because xarray's advanced indexing and slicing is
unlikely to work with jax arrays (at least when a Tracer is used during
jax.jit), and also because some important datatypes used for coordinates, like
timedelta64 and datetime64, are not supported by jax.
For the purposes of tree_util and jax.jit, coordinates are not treated as leaves
of the tree (array data 'contained' by a Dataset/DataArray), they are just a
static part of the structure. That means that if a jit'ed function is called
twice with Dataset inputs that use different coordinates, it will compile a
separate version of the function for each. The coordinates are treated like
static_argnums by jax.jit.
If you want to use dynamic data for coordinates, we recommend making it a
data_var instead of a coord. You won't be able to do indexing and slicing using
the coordinate, but that wasn't going to work with a jax array anyway.
"""
import collections
import contextlib
import contextvars
from typing import Any, Callable, Hashable, Iterator, Mapping, Optional, Union, Tuple, TypeVar, cast
import jax
import jax.numpy as jnp
import numpy as np
import tree
import xarray
def Variable(dims, data, **kwargs) -> xarray.Variable: # pylint:disable=invalid-name
"""Like xarray.Variable, but can wrap JAX arrays."""
return xarray.Variable(dims, wrap(data), **kwargs)
_JAX_COORD_ATTR_NAME = '_jax_coord'
def DataArray( # pylint:disable=invalid-name
data,
coords=None,
dims=None,
name=None,
attrs=None,
jax_coords=None,
) -> xarray.DataArray:
"""Like xarray.DataArray, but supports using JAX arrays.
Args:
data: As for xarray.DataArray, except jax arrays are also supported.
coords: Coordinates for the array, see xarray.DataArray. These coordinates
must be based on plain numpy arrays or something convertible to plain
numpy arrays. Their values will form a static part of the data structure
from the point of view of jax.tree_util. In particular this means these
coordinates will be passed as plain numpy arrays even inside a JIT'd
function, and the JIT'd function will be recompiled under the hood if the
coordinates of DataArrays passed into it change.
If this is not convenient for you, see also jax_coords below.
dims: See xarray.DataArray.
name: See xarray.DataArray.
attrs: See xarray.DataArray.
jax_coords: Additional coordinates, which *can* use JAX arrays. These
coordinates will be treated as JAX data from the point of view of
jax.tree_util, that means when JIT'ing they will be passed as tracers and
computation involving them will be JIT'd.
Unfortunately a side-effect of this is that they can't be used as index
coordinates (because xarray's indexing logic is not JIT-able). If you
specify a coordinate with the same name as a dimension here, it will not
be set as an index coordinate; this behaviour is different to the default
for `coords`, and it means that things like `.sel` based on the jax
coordinate will not work.
Note we require `jax_coords` to be explicitly specified via a different
constructor argument to `coords`, rather than just looking for jax arrays
within the `coords` and treating them differently. This is because it
affects the way jax.tree_util treats them, which is somewhat orthogonal to
whether the value is passed in as numpy or not, and generally needs to be
handled consistently so is something we encourage explicit control over.
Returns:
An instance of xarray.DataArray. Where JAX arrays are used as data or
coords, they will be wrapped with JaxArrayWrapper and can be unwrapped via
`unwrap` and `unwrap_data`.
"""
result = xarray.DataArray(
wrap(data), dims=dims, name=name, attrs=attrs or {})
return assign_coords(result, coords=coords, jax_coords=jax_coords)
def Dataset( # pylint:disable=invalid-name
data_vars,
coords=None,
attrs=None,
jax_coords=None,
) -> xarray.Dataset:
"""Like xarray.Dataset, but can wrap JAX arrays.
Args:
data_vars: As for xarray.Dataset, except jax arrays are also supported.
coords: Coordinates for the dataset, see xarray.Dataset. These coordinates
must be based on plain numpy arrays or something convertible to plain
numpy arrays. Their values will form a static part of the data structure
from the point of view of jax.tree_util. In particular this means these
coordinates will be passed as plain numpy arrays even inside a JIT'd
function, and the JIT'd function will be recompiled under the hood if the
coordinates of DataArrays passed into it change.
If this is not convenient for you, see also jax_coords below.
attrs: See xarray.Dataset.
jax_coords: Additional coordinates, which *can* use JAX arrays. These
coordinates will be treated as JAX data from the point of view of
jax.tree_util, that means when JIT'ing they will be passed as tracers and
computation involving them will be JIT'd.
Unfortunately a side-effect of this is that they can't be used as index
coordinates (because xarray's indexing logic is not JIT-able). If you
specify a coordinate with the same name as a dimension here, it will not
be set as an index coordinate; this behaviour is different to the default
for `coords`, and it means that things like `.sel` based on the jax
coordinate will not work.
Note we require `jax_coords` to be explicitly specified via a different
constructor argument to `coords`, rather than just looking for jax arrays
within the `coords` and treating them differently. This is because it
affects the way jax.tree_util treats them, which is somewhat orthogonal to
whether the value is passed in as numpy or not, and generally needs to be
handled consistently so is something we encourage explicit control over.
Returns:
An instance of xarray.Dataset. Where JAX arrays are used as data, they
will be wrapped with JaxArrayWrapper.
"""
wrapped_data_vars = {}
for name, var_like in data_vars.items():
# xarray.Dataset accepts a few different formats for data_vars:
if isinstance(var_like, jax.Array):
wrapped_data_vars[name] = wrap(var_like)
elif isinstance(var_like, tuple):
# Layout is (dims, data, ...). We wrap data.
wrapped_data_vars[name] = (var_like[0], wrap(var_like[1])) + var_like[2:]
else:
# Could be a plain numpy array or scalar (we don't wrap), or an
# xarray.Variable, DataArray etc, which we must assume is already wrapped
# if necessary (e.g. if creating using xarray_jax.{Variable,DataArray}).
wrapped_data_vars[name] = var_like
result = xarray.Dataset(
data_vars=wrapped_data_vars,
attrs=attrs)
return assign_coords(result, coords=coords, jax_coords=jax_coords)
DatasetOrDataArray = TypeVar(
'DatasetOrDataArray', xarray.Dataset, xarray.DataArray)
def assign_coords(
x: DatasetOrDataArray,
*,
coords: Optional[Mapping[Hashable, Any]] = None,
jax_coords: Optional[Mapping[Hashable, Any]] = None,
) -> DatasetOrDataArray:
"""Replacement for assign_coords which works in presence of jax_coords.
`jax_coords` allow certain specified coordinates to have their data passed as
JAX arrays (including through jax.jit boundaries). The compromise in return is
that they are not created as index coordinates and cannot be used for .sel
and other coordinate-based indexing operations. See docs for `jax_coords` on
xarray_jax.Dataset and xarray_jax.DataArray for more information.
This function can be used to set jax_coords on an existing DataArray or
Dataset, and also to set a mix of jax and non-jax coordinates. It implements
some workarounds to prevent xarray trying and failing to create IndexVariables
from jax arrays under the hood.
If you have any jax_coords with the same name as a dimension, you'll need to
use this function instead of data_array.assign_coords or dataset.assign_coords
in general, to avoid an xarray bug where it tries (and in our case fails) to
create indexes for existing jax coords. See
https://github.com/pydata/xarray/issues/7885.
Args:
x: An xarray Dataset or DataArray.
coords: Dict of (non-JAX) coords, or None if not assigning any.
jax_coords: Dict of JAX coords, or None if not assigning any. See docs for
xarray_jax.Dataset / DataArray for more information on jax_coords.
Returns:
The Dataset or DataArray with coordinates assigned, similarly to
Dataset.assign_coords / DataArray.assign_coords.
"""
coords = {} if coords is None else dict(coords) # Copy before mutating.
jax_coords = {} if jax_coords is None else dict(jax_coords)
# Any existing JAX coords must be dropped and re-added via the workaround
# below, since otherwise .assign_coords will trigger an xarray bug where
# it tries to recreate the indexes again for the existing coordinates.
# Can remove if/when https://github.com/pydata/xarray/issues/7885 fixed.
existing_jax_coords = get_jax_coords(x)
jax_coords = existing_jax_coords | jax_coords
x = x.drop_vars(existing_jax_coords.keys())
# We need to ensure that xarray doesn't try to create an index for
# coordinates with the same name as a dimension, since this will fail if
# given a wrapped JAX tracer.
# It appears the only way to avoid this is to name them differently to any
# dimension name, then rename them back afterwards.
renamed_jax_coords = {}
for name, coord in jax_coords.items():
if isinstance(coord, xarray.DataArray):
coord = coord.variable
if isinstance(coord, xarray.Variable):
coord = coord.copy(deep=False) # Copy before mutating attrs.
else:
# Must wrap as Variable with the correct dims first if this has not
# already been done, otherwise xarray.Dataset will assume the dimension
# name is also __NONINDEX_{n}.
coord = Variable((name,), coord)
# We set an attr on each jax_coord identifying it as such. These attrs on
# the coord Variable get reflected on the coord DataArray exposed too, and
# when set on coordinates they generally get preserved under the default
# keep_attrs setting.
# These attrs are used by jax.tree_util registered flatten/unflatten to
# determine which coords need to be treated as leaves of the flattened
# structure vs static data.
coord.attrs[_JAX_COORD_ATTR_NAME] = True
renamed_jax_coords[f'__NONINDEX_{name}'] = coord
x = x.assign_coords(coords=coords | renamed_jax_coords)
rename_back_mapping = {f'__NONINDEX_{name}': name for name in jax_coords}
if isinstance(x, xarray.Dataset):
# Using 'rename' doesn't work if renaming to the same name as a dimension.
return x.rename_vars(rename_back_mapping)
else: # DataArray
return x.rename(rename_back_mapping)
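# Illustrative sketch (hypothetical `ds`): assigning a mix of static and
# JAX-backed coordinates to an existing Dataset. The static coord becomes part
# of the static pytree structure; the jax_coord is traced through jax.jit.
#
#   ds = assign_coords(
#       ds,
#       coords={'lat': np.arange(3) * 10},
#       jax_coords={'lon': jnp.arange(4) * 10})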
def get_jax_coords(x: DatasetOrDataArray) -> Mapping[Hashable, Any]:
return {
name: coord_var
for name, coord_var in x.coords.variables.items()
if coord_var.attrs.get(_JAX_COORD_ATTR_NAME, False)}
def assign_jax_coords(
x: DatasetOrDataArray,
jax_coords: Optional[Mapping[Hashable, Any]] = None,
**jax_coords_kwargs
) -> DatasetOrDataArray:
"""Assigns only jax_coords, with same API as xarray's assign_coords."""
return assign_coords(x, jax_coords=jax_coords or jax_coords_kwargs)
def wrap(value):
"""Wraps JAX arrays for use in xarray, passing through other values."""
if isinstance(value, jax.Array):
return JaxArrayWrapper(value)
else:
return value
def unwrap(value, require_jax=False):
"""Unwraps wrapped JAX arrays used in xarray, passing through other values."""
if isinstance(value, JaxArrayWrapper):
return value.jax_array
elif isinstance(value, jax.Array):
return value
elif require_jax:
raise TypeError(f'Expected JAX array, found {type(value)}.')
else:
return value
def _wrapped(func):
"""Surrounds a function with JAX array unwrapping/wrapping."""
def wrapped_func(*args, **kwargs):
args, kwargs = tree.map_structure(unwrap, (args, kwargs))
result = func(*args, **kwargs)
return tree.map_structure(wrap, result)
return wrapped_func
def unwrap_data(
value: Union[xarray.Variable, xarray.DataArray],
require_jax: bool = False
) -> Union[jax.Array, np.ndarray]:
"""The unwrapped (see unwrap) data of a an xarray.Variable or DataArray."""
return unwrap(value.data, require_jax=require_jax)
def unwrap_vars(
dataset: Mapping[Hashable, xarray.DataArray],
require_jax: bool = False
) -> Mapping[str, Union[jax.Array, np.ndarray]]:
"""The unwrapped data (see unwrap) of the variables in a dataset."""
# xarray types variable names as Hashable, but in practice they're invariably
# strings and we convert to str to allow for a more useful return type.
return {str(name): unwrap_data(var, require_jax=require_jax)
for name, var in dataset.items()}
def unwrap_coords(
dataset: Union[xarray.Dataset, xarray.DataArray],
require_jax: bool = False
) -> Mapping[str, Union[jax.Array, np.ndarray]]:
"""The unwrapped data (see unwrap) of the coords in a Dataset or DataArray."""
return {str(name): unwrap_data(var, require_jax=require_jax)
for name, var in dataset.coords.items()}
def jax_data(value: Union[xarray.Variable, xarray.DataArray]) -> jax.Array:
"""Like unwrap_data, but will complain if not a jax array."""
# Implementing this separately so we can give a more specific return type
# for it.
return cast(jax.Array, unwrap_data(value, require_jax=True))
def jax_vars(
dataset: Mapping[Hashable, xarray.DataArray]) -> Mapping[str, jax.Array]:
"""Like unwrap_vars, but will complain if vars are not all jax arrays."""
return cast(Mapping[str, jax.Array], unwrap_vars(dataset, require_jax=True))
class JaxArrayWrapper(np.lib.mixins.NDArrayOperatorsMixin):
"""Wraps a JAX array into a duck-typed array suitable for use with xarray.
This uses an older duck-typed array protocol based on __array_ufunc__ and
__array_function__ which works with numpy and xarray. (In newer versions
of xarray it implements xarray.namedarray._typing._array_function.)
This is in the process of being superseded by the Python array API standard
(https://data-apis.org/array-api/latest/index.html), but JAX hasn't
implemented it yet. Once they have, we should be able to get rid of
this wrapper and use JAX arrays directly with xarray.
"""
def __init__(self, jax_array):
self.jax_array = jax_array
def __array_ufunc__(self, ufunc, method, *args, **kwargs):
for x in args:
if not isinstance(x, (jax.typing.ArrayLike, type(self))):
return NotImplemented
if method != '__call__':
return NotImplemented
try:
# Get the corresponding jax.numpy function to the NumPy ufunc:
func = getattr(jnp, ufunc.__name__)
except AttributeError:
return NotImplemented
# There may be an 'out' kwarg requesting an in-place operation, e.g. when
# this is called via __iadd__ (+=), __imul__ (*=) etc. JAX doesn't support
# in-place operations so we just remove this argument and have the ufunc
# return a fresh JAX array instead.
kwargs.pop('out', None)
return _wrapped(func)(*args, **kwargs)
def __array_function__(self, func, types, args, kwargs):
try:
# Get the corresponding jax.numpy function to the NumPy function:
func = getattr(jnp, func.__name__)
except AttributeError:
return NotImplemented
return _wrapped(func)(*args, **kwargs)
def __repr__(self):
return f'xarray_jax.JaxArrayWrapper({repr(self.jax_array)})'
# NDArrayOperatorsMixin already proxies most __dunder__ operator methods.
# We need to proxy through a few more methods in a similar way:
# Essential array properties:
@property
def shape(self):
return self.jax_array.shape
@property
def dtype(self):
return self.jax_array.dtype
@property
def ndim(self):
return self.jax_array.ndim
@property
def size(self):
return self.jax_array.size
@property
def real(self):
return self.jax_array.real
@property
def imag(self):
return self.jax_array.imag
# Array methods not covered by NDArrayOperatorsMixin:
# Allows conversion to numpy array using np.asarray etc. Warning: doing this
# will fail in a jax.jit-ed function.
def __array__(self, dtype=None, context=None):
return np.asarray(self.jax_array, dtype=dtype)
__getitem__ = _wrapped(lambda array, *args: array.__getitem__(*args))
# We drop the kwargs on this as they are not supported by JAX, but xarray
# uses at least one of them (the copy arg).
astype = _wrapped(lambda array, *args, **kwargs: array.astype(*args))
# There are many more methods which are more canonically available via (j)np
# functions, e.g. .sum() available via jnp.sum, and also mean, max, min,
# argmax, argmin etc. We don't attempt to proxy through all of these as
# methods, since this doesn't appear to be expected from a duck-typed array
# implementation. But there are a few which xarray calls as methods, so we
# proxy those:
transpose = _wrapped(jnp.transpose)
reshape = _wrapped(jnp.reshape)
all = _wrapped(jnp.all)
def apply_ufunc(func, *args, require_jax=False, **apply_ufunc_kwargs):
"""Like xarray.apply_ufunc but for jax-specific ufuncs.
Many numpy ufuncs will work fine out of the box with xarray_jax and
JaxArrayWrapper, since JaxArrayWrapper quacks (mostly) like a numpy array and
will convert many numpy operations to jax ops under the hood. For these
situations, xarray.apply_ufunc should work fine.
But sometimes you need a jax-specific ufunc which needs to be given a
jax array as input or return a jax array as output. In that case you should
use this helper as it will remove any JaxArrayWrapper before calling the func,
and wrap the result afterwards before handing it back to xarray.
Args:
func: A function that works with jax arrays (e.g. using functions from
jax.numpy) but otherwise meets the spec for the func argument to
xarray.apply_ufunc.
*args: xarray arguments to be mapped to arguments for func
(see xarray.apply_ufunc).
require_jax: Whether to require that inputs are based on jax arrays or allow
those based on plain numpy arrays too.
**apply_ufunc_kwargs: See xarray.apply_ufunc.
Returns:
Corresponding xarray results (see xarray.apply_ufunc).
"""
def wrapped_func(*maybe_wrapped_args):
unwrapped_args = [unwrap(a, require_jax) for a in maybe_wrapped_args]
result = func(*unwrapped_args)
# Result can be an array or a tuple of arrays, this handles both:
return jax.tree_util.tree_map(wrap, result)
return xarray.apply_ufunc(wrapped_func, *args, **apply_ufunc_kwargs)
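# Illustrative sketch (hypothetical `da` with a 'level' core dimension):
# applying a jax-specific reduction that must receive unwrapped jax arrays.
#
#   summed = apply_ufunc(
#       lambda x: jnp.sum(x, axis=-1),
#       da,
#       input_core_dims=[['level']])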
def pmap(fn: Callable[..., Any],
dim: str,
axis_name: Optional[str] = None,
devices: ... = None,
backend: ... = None) -> Callable[..., Any]:
"""Wraps a subset of jax.pmap functionality to handle xarray input/output.
Constraints:
* Any Dataset or DataArray passed to the function must have `dim` as the
first dimension. This will be checked. You can ensure this if necessary
by calling `.transpose(dim, ...)` beforehand.
* All args and return values will be mapped over the first dimension,
it will use in_axes=0, out_axes=0.
* No support for static_broadcasted_argnums, donate_argnums etc.
Args:
fn: Function to be pmap'd which takes and returns trees which may contain
xarray Dataset/DataArray. Any Dataset/DataArrays passed as input must use
`dim` as the first dimension on all arrays.
dim: The xarray dimension name corresponding to the first dimension that is
pmapped over (pmap is called with in_axes=0, out_axes=0).
axis_name: Used by jax to identify the mapped axis so that parallel
collectives can be applied. Defaults to same as `dim`.
devices:
backend:
See jax.pmap.
Returns:
A pmap'd version of `fn`, which takes and returns Dataset/DataArray with an
extra leading dimension `dim` relative to what the original `fn` sees.
"""
input_treedef = None
output_treedef = None
def fn_passed_to_pmap(*flat_args):
assert input_treedef is not None
# Inside the pmap the original first dimension will no longer be present:
def check_and_remove_leading_dim(dims):
try:
index = dims.index(dim)
except ValueError:
index = None
if index != 0:
raise ValueError(f'Expected dim {dim} at index 0, found at {index}.')
return dims[1:]
with dims_change_on_unflatten(check_and_remove_leading_dim):
args = jax.tree_util.tree_unflatten(input_treedef, flat_args)
result = fn(*args)
nonlocal output_treedef
flat_result, output_treedef = jax.tree_util.tree_flatten(result)
return flat_result
pmapped_fn = jax.pmap(
fn_passed_to_pmap,
axis_name=axis_name or dim,
in_axes=0,
out_axes=0,
devices=devices,
backend=backend)
def result_fn(*args):
nonlocal input_treedef
flat_args, input_treedef = jax.tree_util.tree_flatten(args)
flat_result = pmapped_fn(*flat_args)
assert output_treedef is not None
# After the pmap an extra leading axis will be present, we need to add an
# xarray dimension for this when unflattening the result:
with dims_change_on_unflatten(lambda dims: (dim,) + dims):
return jax.tree_util.tree_unflatten(output_treedef, flat_result)
return result_fn
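# Illustrative sketch (hypothetical `dataset` whose arrays all have a leading
# 'device' dimension of size jax.local_device_count()):
#
#   step = pmap(lambda ds: ds + 1, dim='device')
#   result = step(dataset)  # 'device' is re-attached as the leading dim.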
# Register xarray datatypes with jax.tree_util.
DimsChangeFn = Callable[[Tuple[Hashable, ...]], Tuple[Hashable, ...]]
_DIMS_CHANGE_ON_UNFLATTEN_FN: contextvars.ContextVar[DimsChangeFn] = (
contextvars.ContextVar('dims_change_on_unflatten_fn'))
@contextlib.contextmanager
def dims_change_on_unflatten(dims_change_fn: DimsChangeFn):
"""Can be used to change the dims used when unflattening arrays into xarrays.
This is useful when some axes were added to / removed from the underlying jax
arrays after they were flattened using jax.tree_util.tree_flatten, and you
want to unflatten them again afterwards using the original treedef but
adjusted for the added/removed dimensions.
It can also be used with jax.tree_util.tree_map, when it's called with a
function that adds/removes axes or otherwise changes the axis order.
When dimensions are removed, any coordinates using those removed dimensions
will also be removed on unflatten.
This is implemented as a context manager that sets some thread-local state
affecting the behaviour of our unflatten functions, because it's not possible
to directly modify the treedef to change the dims/coords in it (and with
tree_map, the treedef isn't exposed to you anyway).
Args:
dims_change_fn: Maps a tuple of dimension names for the original
Variable/DataArray/Dataset that was flattened, to an updated tuple of
dimensions which should be used when unflattening.
Yields:
To a context manager in whose scope jax.tree_util.tree_unflatten and
jax.tree_util.tree_map will apply the dims_change_fn before reconstructing
xarrays from jax arrays.
"""
token = _DIMS_CHANGE_ON_UNFLATTEN_FN.set(dims_change_fn)
try:
yield
finally:
_DIMS_CHANGE_ON_UNFLATTEN_FN.reset(token)
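# Illustrative sketch (hypothetical `da`): adding a leading 'batch' dimension
# to every array via tree_map, telling unflatten what to call the new axis.
#
#   with dims_change_on_unflatten(lambda dims: ('batch',) + dims):
#     batched = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, 0), da)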
def _flatten_variable(v: xarray.Variable) -> Tuple[
Tuple[jax.typing.ArrayLike], Tuple[Hashable, ...]]:
"""Flattens a Variable for jax.tree_util."""
children = (unwrap_data(v),)
aux = v.dims
return children, aux
def _unflatten_variable(
aux: Tuple[Hashable, ...],
children: Tuple[jax.typing.ArrayLike]) -> xarray.Variable:
"""Unflattens a Variable for jax.tree_util."""
dims = aux
dims_change_fn = _DIMS_CHANGE_ON_UNFLATTEN_FN.get(None)
if dims_change_fn: dims = dims_change_fn(dims)
return Variable(dims=dims, data=children[0])
def _split_static_and_jax_coords(
coords: xarray.core.coordinates.Coordinates) -> Tuple[
Mapping[Hashable, xarray.Variable], Mapping[Hashable, xarray.Variable]]:
static_coord_vars = {}
jax_coord_vars = {}
for name, coord in coords.items():
if coord.attrs.get(_JAX_COORD_ATTR_NAME, False):
jax_coord_vars[name] = coord.variable
else:
assert not isinstance(coord, (jax.Array, JaxArrayWrapper))
static_coord_vars[name] = coord.variable
return static_coord_vars, jax_coord_vars
def _drop_with_none_of_dims(
coord_vars: Mapping[Hashable, xarray.Variable],
dims: Tuple[Hashable]) -> Mapping[Hashable, xarray.Variable]:
return {name: var for name, var in coord_vars.items()
if set(var.dims) <= set(dims)}
class _HashableCoords(collections.abc.Mapping):
"""Wraps a dict of xarray Variables as hashable, used for static coordinates.
This needs to be hashable so that when an xarray.Dataset is passed to a
jax.jit'ed function, jax can check whether it's seen an array with the
same static coordinates(*) before or whether it needs to recompile the
function for the new values of the static coordinates.
(*) note jax_coords are not included in this; their value can be different
on different calls without triggering a recompile.
"""
def __init__(self, coord_vars: Mapping[Hashable, xarray.Variable]):
self._variables = coord_vars
def __repr__(self) -> str:
return f'_HashableCoords({repr(self._variables)})'
def __getitem__(self, key: Hashable) -> xarray.Variable:
return self._variables[key]
def __len__(self) -> int:
return len(self._variables)
def __iter__(self) -> Iterator[Hashable]:
return iter(self._variables)
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash(frozenset((name, var.data.tobytes())
for name, var in self._variables.items()))
return self._hash
def __eq__(self, other):
if self is other:
return True
elif not isinstance(other, type(self)):
return NotImplemented
elif self._variables is other._variables:
return True
else:
return self._variables.keys() == other._variables.keys() and all(
variable.equals(other._variables[name])
for name, variable in self._variables.items())
def _flatten_data_array(v: xarray.DataArray) -> Tuple[
# Children (data variable, jax_coord_vars):
Tuple[xarray.Variable, Mapping[Hashable, xarray.Variable]],
# Static auxiliary data (name, static_coord_vars):
Tuple[Optional[Hashable], _HashableCoords]]:
"""Flattens a DataArray for jax.tree_util."""
static_coord_vars, jax_coord_vars = _split_static_and_jax_coords(v.coords)
children = (v.variable, jax_coord_vars)
aux = (v.name, _HashableCoords(static_coord_vars))
return children, aux
def _unflatten_data_array(
aux: Tuple[Optional[Hashable], _HashableCoords],
children: Tuple[xarray.Variable, Mapping[Hashable, xarray.Variable]],
) -> xarray.DataArray:
"""Unflattens a DataArray for jax.tree_util."""
variable, jax_coord_vars = children
name, static_coord_vars = aux
# Drop static coords which have dims not present in any of the data_vars.
# These would generally be dims that were dropped by a dims_change_fn, but
# because static coordinates don't go through dims_change_fn on unflatten, we
# just drop them where this causes a problem.
# Since jax_coords go through the dims_change_fn on unflatten we don't need
# to do this for jax_coords.
static_coord_vars = _drop_with_none_of_dims(static_coord_vars, variable.dims)
return DataArray(
variable, name=name, coords=static_coord_vars, jax_coords=jax_coord_vars)
def _flatten_dataset(dataset: xarray.Dataset) -> Tuple[
# Children (data variables, jax_coord_vars):
Tuple[Mapping[Hashable, xarray.Variable],
Mapping[Hashable, xarray.Variable]],
# Static auxiliary data (static_coord_vars):
_HashableCoords]:
"""Flattens a Dataset for jax.tree_util."""
variables = {name: data_array.variable
for name, data_array in dataset.data_vars.items()}
static_coord_vars, jax_coord_vars = _split_static_and_jax_coords(
dataset.coords)
children = (variables, jax_coord_vars)
aux = _HashableCoords(static_coord_vars)
return children, aux
def _unflatten_dataset(
aux: _HashableCoords,
children: Tuple[Mapping[Hashable, xarray.Variable],
Mapping[Hashable, xarray.Variable]],
) -> xarray.Dataset:
"""Unflattens a Dataset for jax.tree_util."""
data_vars, jax_coord_vars = children
static_coord_vars = aux
dataset = xarray.Dataset(data_vars)
# Drop static coords which have dims not present in any of the data_vars.
# See corresponding comment in _unflatten_data_array.
static_coord_vars = _drop_with_none_of_dims(static_coord_vars, dataset.dims) # pytype: disable=wrong-arg-types
return assign_coords(
dataset, coords=static_coord_vars, jax_coords=jax_coord_vars)
jax.tree_util.register_pytree_node(
xarray.Variable, _flatten_variable, _unflatten_variable)
# This is a subclass of Variable but still needs registering separately.
# Flatten/unflatten for IndexVariable is a bit of a corner case but we do
# need to support it.
jax.tree_util.register_pytree_node(
xarray.IndexVariable, _flatten_variable, _unflatten_variable)
jax.tree_util.register_pytree_node(
xarray.DataArray, _flatten_data_array, _unflatten_data_array)
jax.tree_util.register_pytree_node(
xarray.Dataset, _flatten_dataset, _unflatten_dataset)
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for xarray_jax."""
from absl.testing import absltest
import chex
from graphcast import xarray_jax
import jax
import jax.numpy as jnp
import numpy as np
import xarray
class XarrayJaxTest(absltest.TestCase):
def test_jax_array_wrapper_with_numpy_api(self):
# This is just a side benefit of making things work with xarray, but the
# JaxArrayWrapper does allow you to manipulate JAX arrays using the
# standard numpy API, without converting them to numpy in the process:
ones = jnp.ones((3, 4), dtype=np.float32)
x = xarray_jax.JaxArrayWrapper(ones)
x = np.abs((x + 2) * (x - 3))
x = x[:-1, 1:3]
x = np.concatenate([x, x + 1], axis=0)
x = np.transpose(x, (1, 0))
x = np.reshape(x, (-1,))
x = x.astype(np.int32)
self.assertIsInstance(x, xarray_jax.JaxArrayWrapper)
# An explicit conversion gets us out of JAX-land however:
self.assertIsInstance(np.asarray(x), np.ndarray)
def test_jax_xarray_variable(self):
def ops_via_xarray(inputs):
x = xarray_jax.Variable(('lat', 'lon'), inputs)
# We'll apply a sequence of operations just to test that the end result is
# still a JAX array, i.e. we haven't converted to numpy at any point.
x = np.abs((x + 2) * (x - 3))
x = x.isel({'lat': slice(0, -1), 'lon': slice(1, 3)})
x = xarray.Variable.concat([x, x + 1], dim='lat')
x = x.transpose('lon', 'lat')
x = x.stack(channels=('lon', 'lat'))
x = x.sum()
return xarray_jax.jax_data(x)
# Check it doesn't leave jax-land when passed concrete values:
ones = jnp.ones((3, 4), dtype=np.float32)
result = ops_via_xarray(ones)
self.assertIsInstance(result, jax.Array)
# And that you can JIT it and compute gradients through it. These will
# involve passing jax tracers through the xarray computation:
jax.jit(ops_via_xarray)(ones)
jax.grad(ops_via_xarray)(ones)
def test_jax_xarray_data_array(self):
def ops_via_xarray(inputs):
x = xarray_jax.DataArray(dims=('lat', 'lon'),
data=inputs,
coords={'lat': np.arange(3) * 10,
'lon': np.arange(4) * 10})
x = np.abs((x + 2) * (x - 3))
x = x.sel({'lat': slice(0, 20)})
y = xarray_jax.DataArray(dims=('lat', 'lon'),
data=ones,
coords={'lat': np.arange(3, 6) * 10,
'lon': np.arange(4) * 10})
x = xarray.concat([x, y], dim='lat')
x = x.transpose('lon', 'lat')
x = x.stack(channels=('lon', 'lat'))
x = x.unstack()
x = x.sum()
return xarray_jax.jax_data(x)
ones = jnp.ones((3, 4), dtype=np.float32)
result = ops_via_xarray(ones)
self.assertIsInstance(result, jax.Array)
jax.jit(ops_via_xarray)(ones)
jax.grad(ops_via_xarray)(ones)
def test_jax_xarray_dataset(self):
def ops_via_xarray(foo, bar):
x = xarray_jax.Dataset(
data_vars={'foo': (('lat', 'lon'), foo),
'bar': (('time', 'lat', 'lon'), bar)},
coords={
'time': np.arange(2),
'lat': np.arange(3) * 10,
'lon': np.arange(4) * 10})
x = np.abs((x + 2) * (x - 3))
x = x.sel({'lat': slice(0, 20)})
y = xarray_jax.Dataset(
data_vars={'foo': (('lat', 'lon'), foo),
'bar': (('time', 'lat', 'lon'), bar)},
coords={
'time': np.arange(2),
'lat': np.arange(3, 6) * 10,
'lon': np.arange(4) * 10})
x = xarray.concat([x, y], dim='lat')
x = x.transpose('lon', 'lat', 'time')
x = x.stack(channels=('lon', 'lat'))
x = (x.foo + x.bar).sum()
return xarray_jax.jax_data(x)
foo = jnp.ones((3, 4), dtype=np.float32)
bar = jnp.ones((2, 3, 4), dtype=np.float32)
result = ops_via_xarray(foo, bar)
self.assertIsInstance(result, jax.Array)
jax.jit(ops_via_xarray)(foo, bar)
jax.grad(ops_via_xarray)(foo, bar)
def test_jit_function_with_xarray_variable_arguments_and_return(self):
function = jax.jit(lambda v: v + 1)
with self.subTest('jax input'):
inputs = xarray_jax.Variable(
('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
_ = function(inputs)
# We test running the jitted function a second time, to exercise logic in
# jax which checks if the structure of the inputs (including dimension
# names and coordinates) is the same as it was for the previous call and
# so whether it needs to re-trace-and-compile a new version of the
# function or not. This can run into problems if the 'aux' structure
# returned by the registered flatten function is not hashable/comparable.
outputs = function(inputs)
self.assertEqual(outputs.dims, inputs.dims)
with self.subTest('numpy input'):
inputs = xarray.Variable(
('lat', 'lon'), np.ones((3, 4), dtype=np.float32))
_ = function(inputs)
outputs = function(inputs)
self.assertEqual(outputs.dims, inputs.dims)
def test_jit_problem_if_convert_to_plain_numpy_array(self):
inputs = xarray_jax.DataArray(
data=jnp.ones((2,), dtype=np.float32), dims=('foo',))
with self.assertRaises(jax.errors.TracerArrayConversionError):
# Calling .values on a DataArray converts its values to numpy:
jax.jit(lambda data_array: data_array.values)(inputs)
def test_grad_function_with_xarray_variable_arguments(self):
x = xarray_jax.Variable(('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
# For grad we still need a JAX scalar as the output:
jax.grad(lambda v: xarray_jax.jax_data(v.sum()))(x)
def test_jit_function_with_xarray_data_array_arguments_and_return(self):
inputs = xarray_jax.DataArray(
data=jnp.ones((3, 4), dtype=np.float32),
dims=('lat', 'lon'),
coords={'lat': np.arange(3),
'lon': np.arange(4) * 10})
fn = jax.jit(lambda v: v + 1)
_ = fn(inputs)
outputs = fn(inputs)
self.assertEqual(outputs.dims, inputs.dims)
chex.assert_trees_all_equal(outputs.coords, inputs.coords)
def test_jit_function_with_data_array_and_jax_coords(self):
inputs = xarray_jax.DataArray(
data=jnp.ones((3, 4), dtype=np.float32),
dims=('lat', 'lon'),
coords={'lat': np.arange(3)},
jax_coords={'lon': jnp.arange(4) * 10})
# Verify the jax_coord 'lon' retains jax data, and has not been created
# as an index coordinate:
self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
self.assertNotIn('lon', inputs.indexes)
@jax.jit
def fn(v):
# The non-JAX coord is passed with numpy array data and an index:
self.assertIsInstance(v.coords['lat'].data, np.ndarray)
self.assertIn('lat', v.indexes)
# The jax_coord is passed with JAX array data:
self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper)
self.assertNotIn('lon', v.indexes)
# Use the jax coord in the computation:
v = v + v.coords['lon']
# Return with an updated jax coord:
return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1)
_ = fn(inputs)
outputs = fn(inputs)
# Verify the jax_coord 'lon' has jax data in the output too:
self.assertIsInstance(
outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
self.assertNotIn('lon', outputs.indexes)
self.assertEqual(outputs.dims, inputs.dims)
chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat'])
# Check our computations with the coordinate values worked:
chex.assert_trees_all_equal(
outputs.coords['lon'].data, (inputs.coords['lon']+1).data)
chex.assert_trees_all_equal(
outputs.data, (inputs + inputs.coords['lon']).data)
def test_jit_function_with_xarray_dataset_arguments_and_return(self):
foo = jnp.ones((3, 4), dtype=np.float32)
bar = jnp.ones((2, 3, 4), dtype=np.float32)
inputs = xarray_jax.Dataset(
data_vars={'foo': (('lat', 'lon'), foo),
'bar': (('time', 'lat', 'lon'), bar)},
coords={
'time': np.arange(2),
'lat': np.arange(3) * 10,
'lon': np.arange(4) * 10})
fn = jax.jit(lambda v: v + 1)
_ = fn(inputs)
outputs = fn(inputs)
self.assertEqual({'foo', 'bar'}, outputs.data_vars.keys())
self.assertEqual(inputs.foo.dims, outputs.foo.dims)
self.assertEqual(inputs.bar.dims, outputs.bar.dims)
chex.assert_trees_all_equal(outputs.coords, inputs.coords)
def test_jit_function_with_dataset_and_jax_coords(self):
foo = jnp.ones((3, 4), dtype=np.float32)
bar = jnp.ones((2, 3, 4), dtype=np.float32)
inputs = xarray_jax.Dataset(
data_vars={'foo': (('lat', 'lon'), foo),
'bar': (('time', 'lat', 'lon'), bar)},
coords={
'time': np.arange(2),
'lat': np.arange(3) * 10,
},
jax_coords={'lon': jnp.arange(4) * 10}
)
# Verify the jax_coord 'lon' retains jax data, and has not been created
# as an index coordinate:
self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
self.assertNotIn('lon', inputs.indexes)
@jax.jit
def fn(v):
# The non-JAX coords are passed with numpy array data and an index:
self.assertIsInstance(v.coords['lat'].data, np.ndarray)
self.assertIn('lat', v.indexes)
# The jax_coord is passed with JAX array data:
self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper)
self.assertNotIn('lon', v.indexes)
# Use the jax coord in the computation:
v = v + v.coords['lon']
# Return with an updated jax coord:
return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1)
_ = fn(inputs)
outputs = fn(inputs)
# Verify the jax_coord 'lon' has jax data in the output too:
self.assertIsInstance(
outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
self.assertNotIn('lon', outputs.indexes)
self.assertEqual(outputs.dims, inputs.dims)
chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat'])
# Check our computations with the coordinate values worked:
chex.assert_trees_all_equal(
(outputs.coords['lon']).data,
(inputs.coords['lon']+1).data,
)
outputs_dict = {key: outputs[key].data for key in outputs}
inputs_and_inputs_coords_dict = {
key: (inputs + inputs.coords['lon'])[key].data
for key in inputs + inputs.coords['lon']
}
chex.assert_trees_all_equal(outputs_dict, inputs_and_inputs_coords_dict)
def test_flatten_unflatten_variable(self):
variable = xarray_jax.Variable(
('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
children, aux = xarray_jax._flatten_variable(variable)
# Check auxiliary info is hashable/comparable (important for jax.jit):
hash(aux)
self.assertEqual(aux, aux)
roundtrip = xarray_jax._unflatten_variable(aux, children)
self.assertTrue(variable.equals(roundtrip))
def test_flatten_unflatten_data_array(self):
data_array = xarray_jax.DataArray(
data=jnp.ones((3, 4), dtype=np.float32),
dims=('lat', 'lon'),
coords={'lat': np.arange(3)},
jax_coords={'lon': np.arange(4) * 10},
)
children, aux = xarray_jax._flatten_data_array(data_array)
# Check auxiliary info is hashable/comparable (important for jax.jit):
hash(aux)
self.assertEqual(aux, aux)
roundtrip = xarray_jax._unflatten_data_array(aux, children)
self.assertTrue(data_array.equals(roundtrip))
def test_flatten_unflatten_dataset(self):
foo = jnp.ones((3, 4), dtype=np.float32)
bar = jnp.ones((2, 3, 4), dtype=np.float32)
dataset = xarray_jax.Dataset(
data_vars={'foo': (('lat', 'lon'), foo),
'bar': (('time', 'lat', 'lon'), bar)},
coords={
'time': np.arange(2),
'lat': np.arange(3) * 10},
jax_coords={
'lon': np.arange(4) * 10})
children, aux = xarray_jax._flatten_dataset(dataset)
# Check auxiliary info is hashable/comparable (important for jax.jit):
hash(aux)
self.assertEqual(aux, aux)
roundtrip = xarray_jax._unflatten_dataset(aux, children)
self.assertTrue(dataset.equals(roundtrip))
def test_flatten_unflatten_added_dim(self):
data_array = xarray_jax.DataArray(
data=jnp.ones((3, 4), dtype=np.float32),
dims=('lat', 'lon'),
coords={'lat': np.arange(3),
'lon': np.arange(4) * 10})
leaves, treedef = jax.tree_util.tree_flatten(data_array)
leaves = [jnp.expand_dims(x, 0) for x in leaves]
with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims):
with_new_dim = jax.tree_util.tree_unflatten(treedef, leaves)
self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
xarray.testing.assert_identical(
jax.device_get(data_array),
jax.device_get(with_new_dim.isel(new=0)))
def test_map_added_dim(self):
data_array = xarray_jax.DataArray(
data=jnp.ones((3, 4), dtype=np.float32),
dims=('lat', 'lon'),
coords={'lat': np.arange(3),
'lon': np.arange(4) * 10})
with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims):
with_new_dim = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, 0),
data_array)
self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
xarray.testing.assert_identical(
jax.device_get(data_array),
jax.device_get(with_new_dim.isel(new=0)))
def test_map_remove_dim(self):
foo = jnp.ones((1, 3, 4), dtype=np.float32)
bar = jnp.ones((1, 2, 3, 4), dtype=np.float32)
dataset = xarray_jax.Dataset(
data_vars={'foo': (('batch', 'lat', 'lon'), foo),
'bar': (('batch', 'time', 'lat', 'lon'), bar)},
coords={
'batch': np.array([123]),
'time': np.arange(2),
'lat': np.arange(3) * 10,
'lon': np.arange(4) * 10})
with xarray_jax.dims_change_on_unflatten(lambda dims: dims[1:]):
with_removed_dim = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, 0),
dataset)
self.assertEqual(('lat', 'lon'), with_removed_dim['foo'].dims)
self.assertEqual(('time', 'lat', 'lon'), with_removed_dim['bar'].dims)
self.assertNotIn('batch', with_removed_dim.dims)
self.assertNotIn('batch', with_removed_dim.coords)
xarray.testing.assert_identical(
jax.device_get(dataset.isel(batch=0, drop=True)),
jax.device_get(with_removed_dim))
def test_pmap(self):
devices = jax.local_device_count()
foo = jnp.zeros((devices, 3, 4), dtype=np.float32)
bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32)
dataset = xarray_jax.Dataset({
'foo': (('device', 'lat', 'lon'), foo),
'bar': (('device', 'time', 'lat', 'lon'), bar)})
def func(d):
self.assertNotIn('device', d.dims)
return d + 1
func = xarray_jax.pmap(func, dim='device')
result = func(dataset)
xarray.testing.assert_identical(
jax.device_get(dataset + 1),
jax.device_get(result))
# Can call it again with a different argument structure (it will recompile
# under the hood but should work):
dataset = dataset.drop_vars('foo')
result = func(dataset)
xarray.testing.assert_identical(
jax.device_get(dataset + 1),
jax.device_get(result))
def test_pmap_with_jax_coords(self):
devices = jax.local_device_count()
foo = jnp.zeros((devices, 3, 4), dtype=np.float32)
bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32)
time = jnp.zeros((devices, 2), dtype=np.float32)
dataset = xarray_jax.Dataset(
{'foo': (('device', 'lat', 'lon'), foo),
'bar': (('device', 'time', 'lat', 'lon'), bar)},
coords={
'lat': np.arange(3),
'lon': np.arange(4),
},
jax_coords={
# Currently any jax_coords need a leading device dimension to use
# with pmap, same as for data_vars.
# TODO(matthjw): have pmap automatically broadcast to all devices
# where the device dimension not present.
'time': xarray_jax.Variable(('device', 'time'), time),
}
)
def func(d):
self.assertNotIn('device', d.dims)
self.assertNotIn('device', d.coords['time'].dims)
# The jax_coord 'time' should be passed in backed by a JAX array, but
# not as an index coordinate.
self.assertIsInstance(d.coords['time'].data, xarray_jax.JaxArrayWrapper)
self.assertNotIn('time', d.indexes)
return d + 1
func = xarray_jax.pmap(func, dim='device')
result = func(dataset)
xarray.testing.assert_identical(
jax.device_get(dataset + 1),
jax.device_get(result))
# Can call it again with a different argument structure (it will recompile
# under the hood but should work):
dataset = dataset.drop_vars('foo')
result = func(dataset)
xarray.testing.assert_identical(
jax.device_get(dataset + 1),
jax.device_get(result))
def test_pmap_with_tree_mix_of_xarray_and_jax_array(self):
devices = jax.local_device_count()
data_array = xarray_jax.DataArray(
data=jnp.ones((devices, 3, 4), dtype=np.float32),
dims=('device', 'lat', 'lon'))
plain_array = jnp.ones((devices, 2), dtype=np.float32)
inputs = {'foo': data_array,
'bar': plain_array}
def func(x):
return x['foo'] + 1, x['bar'] + 1
func = xarray_jax.pmap(func, dim='device')
result_foo, result_bar = func(inputs)
xarray.testing.assert_identical(
jax.device_get(inputs['foo'] + 1),
jax.device_get(result_foo))
np.testing.assert_array_equal(
jax.device_get(inputs['bar'] + 1),
jax.device_get(result_bar))
def test_pmap_complains_when_dim_not_first(self):
devices = jax.local_device_count()
data_array = xarray_jax.DataArray(
data=jnp.ones((3, devices, 4), dtype=np.float32),
dims=('lat', 'device', 'lon'))
func = xarray_jax.pmap(lambda x: x+1, dim='device')
with self.assertRaisesRegex(
ValueError, 'Expected dim device at index 0, found at 1'):
func(data_array)
def test_apply_ufunc(self):
inputs = xarray_jax.DataArray(
data=jnp.asarray([[1, 2], [3, 4]]),
dims=('x', 'y'),
coords={'x': [0, 1],
'y': [2, 3]})
result = xarray_jax.apply_ufunc(
lambda x: jnp.sum(x, axis=-1),
inputs,
input_core_dims=[['x']])
expected_result = xarray_jax.DataArray(
data=[4, 6],
dims=('y',),
coords={'y': [2, 3]})
xarray.testing.assert_identical(expected_result, jax.device_get(result))
def test_apply_ufunc_multiple_return_values(self):
def ufunc(array):
return jnp.min(array, axis=-1), jnp.max(array, axis=-1)
inputs = xarray_jax.DataArray(
data=jnp.asarray([[1, 4], [3, 2]]),
dims=('x', 'y'),
coords={'x': [0, 1],
'y': [2, 3]})
result = xarray_jax.apply_ufunc(
ufunc, inputs, input_core_dims=[['x']], output_core_dims=[[], []])
expected = (
# Mins:
xarray_jax.DataArray(
data=[1, 2],
dims=('y',),
coords={'y': [2, 3]}
),
# Maxes:
xarray_jax.DataArray(
data=[3, 4],
dims=('y',),
coords={'y': [2, 3]}
)
)
xarray.testing.assert_identical(expected[0], jax.device_get(result[0]))
xarray.testing.assert_identical(expected[1], jax.device_get(result[1]))
if __name__ == '__main__':
absltest.main()
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with trees of xarray.DataArray (including Datasets).
Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library;
it won't work as a leaf node since it implements Mapping, but also won't work
as an internal node since tree doesn't know how to re-create it properly.
To fix this, we reimplement a subset of `map_structure`, exposing its
constituent DataArrays as leaf nodes. This means it can be mapped over as a
generic container of DataArrays, while still preserving the result as a Dataset
where possible.
This is useful because in a few places we need to handle a general
Mapping[str, DataArray] (where the coordinates might not be compatible across
the constituent DataArrays) but also the special case of a Dataset nicely.
For the result e.g. of a tree.map_structure(fn, dataset), if fn returns None for
some of the child DataArrays, they will be omitted from the returned dataset. If
any values other than DataArrays or None are returned, then we don't attempt to
return a Dataset and just return a plain dict of the results. Similarly if
DataArrays are returned but with non-matching coordinates, it will just return a
plain dict of DataArrays.
Note xarray datatypes are registered with `jax.tree_util` by xarray_jax.py,
but `jax.tree_util.tree_map` is distinct from `xarray_tree.map_structure`,
as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the
latter exposes DataArrays as leaf nodes.
"""
from typing import Any, Callable
import xarray
def map_structure(func: Callable[..., Any], *structures: Any) -> Any:
"""Maps func through given structures with xarrays. See tree.map_structure."""
if not callable(func):
raise TypeError(f'func must be callable, got: {func}')
if not structures:
raise ValueError('Must provide at least one structure')
first = structures[0]
if isinstance(first, xarray.Dataset):
data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
if all(isinstance(a, (type(None), xarray.DataArray))
for a in data.values()):
data_arrays = [v.rename(k) for k, v in data.items() if v is not None]
try:
return xarray.merge(data_arrays, join='exact')
except ValueError: # Exact join not possible.
pass
return data
if isinstance(first, dict):
return {k: map_structure(func, *[s[k] for s in structures])
for k in first.keys()}
if isinstance(first, (list, tuple, set)):
return type(first)(map_structure(func, *s) for s in zip(*structures))
return func(*structures)
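# Illustrative sketch (hypothetical `ds`): subtracting the mean from every
# data variable; the results are merged back into a Dataset when their
# coordinates still match exactly, otherwise a plain dict is returned.
#
#   anomalies = map_structure(lambda da: da - da.mean(), ds)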
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for xarray_tree."""
from absl.testing import absltest
from graphcast import xarray_tree
import numpy as np
import xarray
TEST_DATASET = xarray.Dataset(
data_vars={
"foo": (("x", "y"), np.zeros((2, 3))),
"bar": (("x",), np.zeros((2,))),
},
coords={
"x": [1, 2],
"y": [10, 20, 30],
}
)
class XarrayTreeTest(absltest.TestCase):
def test_map_structure_maps_over_leaves_but_preserves_dataset_type(self):
def fn(leaf):
self.assertIsInstance(leaf, xarray.DataArray)
result = leaf + 1
# Removing the name from the returned DataArray to test that we don't rely
# on it being present to restore the correct names in the result:
result = result.rename(None)
return result
result = xarray_tree.map_structure(fn, TEST_DATASET)
self.assertIsInstance(result, xarray.Dataset)
self.assertSameElements({"foo", "bar"}, result.keys())
def test_map_structure_on_data_arrays(self):
data_arrays = dict(TEST_DATASET)
result = xarray_tree.map_structure(lambda x: x+1, data_arrays)
self.assertIsInstance(result, dict)
self.assertSameElements({"foo", "bar"}, result.keys())
def test_map_structure_on_dataset_plain_dict_when_coords_incompatible(self):
def fn(leaf):
# Returns DataArrays that can't be exactly merged back into a Dataset
# due to the coordinates not matching:
if leaf.name == "foo":
return xarray.DataArray(
data=np.zeros(2), dims=("x",), coords={"x": [1, 2]})
else:
return xarray.DataArray(
data=np.zeros(2), dims=("x",), coords={"x": [3, 4]})
result = xarray_tree.map_structure(fn, TEST_DATASET)
self.assertIsInstance(result, dict)
self.assertSameElements({"foo", "bar"}, result.keys())
def test_map_structure_on_dataset_drops_vars_with_none_return_values(self):
def fn(leaf):
return leaf if leaf.name == "foo" else None
result = xarray_tree.map_structure(fn, TEST_DATASET)
self.assertIsInstance(result, xarray.Dataset)
self.assertSameElements({"foo"}, result.keys())
def test_map_structure_on_dataset_returns_plain_dict_other_return_types(self):
def fn(leaf):
self.assertIsInstance(leaf, xarray.DataArray)
return "not a DataArray"
result = xarray_tree.map_structure(fn, TEST_DATASET)
self.assertEqual({"foo": "not a DataArray",
"bar": "not a DataArray"}, result)
def test_map_structure_two_args_different_variable_orders(self):
dataset_different_order = TEST_DATASET[["bar", "foo"]]
def fn(arg1, arg2):
self.assertEqual(arg1.name, arg2.name)
xarray_tree.map_structure(fn, TEST_DATASET, dataset_different_order)
if __name__ == "__main__":
absltest.main()
# Unique model identifier
modelCode=670
# Model name
modelName=graphcast_jax
# Model description
modelDescription=Weather forecasting with deep learning.
# Application scenarios
appScenario=inference,training,weather forecasting,meteorology,transportation,environment
# Framework type
frameType=jax
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module setuptools script."""
from setuptools import setup
description = (
"GraphCast: Learning skillful medium-range global weather forecasting"
)
setup(
name="graphcast",
version="0.1",
description=description,
long_description=description,
author="DeepMind",
license="Apache License, Version 2.0",
keywords="GraphCast Weather Prediction",
url="https://github.com/deepmind/graphcast",
packages=["graphcast"],
install_requires=[
"cartopy",
"chex",
"colabtools",
"dask",
"dm-haiku",
"dm-tree",
"jax",
"jraph",
"matplotlib",
"numpy",
"pandas",
"rtree",
"scipy",
"trimesh",
"typing_extensions",
"xarray",
],
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: POSIX :: Linux",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Atmospheric Science",
"Topic :: Scientific/Engineering :: Physics",
],
)