Commit 4d04d055 authored by mashun1

graphcast
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for creating icosahedral meshes."""
import itertools
from typing import List, NamedTuple, Sequence, Tuple
import numpy as np
from scipy.spatial import transform
class TriangularMesh(NamedTuple):
"""Data structure for triangular meshes.
Attributes:
vertices: spatial positions of the vertices of the mesh of shape
[num_vertices, num_dims].
faces: triangular faces of the mesh of shape [num_faces, 3]. Contains
integer indices into `vertices`.
"""
vertices: np.ndarray
faces: np.ndarray
def merge_meshes(
mesh_list: Sequence[TriangularMesh]) -> TriangularMesh:
"""Merges all meshes into one. Assumes the last mesh is the finest.
Args:
mesh_list: Sequence of meshes, from coarse to fine refinement levels. The
vertices and faces may contain those from preceding, coarser levels.
Returns:
`TriangularMesh` for which the vertices correspond to the highest
resolution mesh in the hierarchy, and the faces are the union of the
faces at all levels of the hierarchy.
"""
for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list):
num_nodes_mesh_i = mesh_i.vertices.shape[0]
assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i])
return TriangularMesh(
vertices=mesh_list[-1].vertices,
faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0))
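# A minimal usage sketch, added for illustration only; the helper name below is
# hypothetical and not part of the original module. It merges a two-level
# hierarchy: the merged mesh keeps the finest vertices and the union of faces.
def _example_merge_two_level_hierarchy() -> TriangularMesh:
  """Hypothetical sketch: merges an icosahedron plus one split."""
  # `get_hierarchy_of_triangular_meshes_for_sphere` is defined later in this
  # module; the name is resolved at call time.
  meshes = get_hierarchy_of_triangular_meshes_for_sphere(splits=1)
  merged = merge_meshes(meshes)
  # Vertices of the finest mesh (12 + 30 = 42), union of 20 + 80 faces.
  assert merged.vertices.shape == (42, 3)
  assert merged.faces.shape == (20 + 80, 3)
  return merged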
def get_hierarchy_of_triangular_meshes_for_sphere(
splits: int) -> List[TriangularMesh]:
"""Returns a sequence of meshes, each with triangularization sphere.
Starting with a regular icosahedron (12 vertices, 20 faces, 30 edges) with
circumscribed unit sphere. Then, each triangular face is iteratively
subdivided into 4 triangular faces `splits` times. The new vertices are then
projected back onto the unit sphere. All resulting meshes are returned in a
list, from lowest to highest resolution.
The vertices in each face are specified in counter-clockwise order as
observed from outside the icosahedron.
Args:
splits: How many times to split each triangle.
Returns:
Sequence of `TriangularMesh`s of length `splits + 1` each with:
vertices: [num_vertices, 3] vertex positions in 3D, all with unit norm.
faces: [num_faces, 3] with triangular faces joining sets of 3 vertices.
Each row contains three indices into the vertices array, indicating
the vertices adjacent to the face. Always with positive orientation
(counter-clockwise when looking from the outside).
"""
current_mesh = get_icosahedron()
output_meshes = [current_mesh]
for _ in range(splits):
current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh)
output_meshes.append(current_mesh)
return output_meshes
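# Illustration only (hypothetical helper, not in the original API): each split
# multiplies the face count by 4 and adds one new vertex per edge of the
# previous level, e.g. 12/20, 42/80, 162/320 vertices/faces.
def _example_hierarchy_sizes(splits: int = 2) -> None:
  """Hypothetical sketch: prints vertex/face counts per refinement level."""
  for level, mesh in enumerate(
      get_hierarchy_of_triangular_meshes_for_sphere(splits=splits)):
    print(level, mesh.vertices.shape[0], mesh.faces.shape[0])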
def get_icosahedron() -> TriangularMesh:
"""Returns a regular icosahedral mesh with circumscribed unit sphere.
See https://en.wikipedia.org/wiki/Regular_icosahedron#Cartesian_coordinates
for details on the construction of the regular icosahedron.
The vertices in each face are specified in counter-clockwise order as observed
from the outside of the icosahedron.
Returns:
TriangularMesh with:
vertices: [num_vertices=12, 3] vertex positions in 3D, all with unit norm.
faces: [num_faces=20, 3] with triangular faces joining sets of 3 vertices.
Each row contains three indices into the vertices array, indicating
the vertices adjacent to the face. Always with positive orientation (
counter-clockwise when looking from the outside).
"""
phi = (1 + np.sqrt(5)) / 2
vertices = []
for c1 in [1., -1.]:
for c2 in [phi, -phi]:
vertices.append((c1, c2, 0.))
vertices.append((0., c1, c2))
vertices.append((c2, 0., c1))
vertices = np.array(vertices, dtype=np.float32)
vertices /= np.linalg.norm([1., phi])
# I did this manually, checking the orientation one by one.
faces = [(0, 1, 2),
(0, 6, 1),
(8, 0, 2),
(8, 4, 0),
(3, 8, 2),
(3, 2, 7),
(7, 2, 1),
(0, 4, 6),
(4, 11, 6),
(6, 11, 5),
(1, 5, 7),
(4, 10, 11),
(4, 8, 10),
(10, 8, 3),
(10, 3, 9),
(11, 10, 9),
(11, 9, 5),
(5, 9, 7),
(9, 3, 7),
(1, 6, 5),
]
# By default the top is an edge parallel to the Y axis. We need to rotate
# around the y axis by half of the supplement of the angle between faces,
# i.e. (pi - angle_between_faces) / 2, to get the desired orientation:
#
#                          /O\   (top edge)
#                         /   \                          Z
#        (adjacent face) /     \ (adjacent face)         ^
#                       / angle_between_faces \          |
#                      /                       \         |
#                     /                         \       YO-----> X
#
# This results in the adjacent face becoming the top plane:
#
#   ----------------------O\   (top edge)
#                           \
#                            \
#                             \   (adjacent face)
#                              \
#                               \
#                                \
angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3))
rotation_angle = (np.pi - angle_between_faces) / 2
rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle)
rotation_matrix = rotation.as_matrix()
vertices = np.dot(vertices, rotation_matrix)
return TriangularMesh(vertices=vertices.astype(np.float32),
faces=np.array(faces, dtype=np.int32))
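# Illustration only (hypothetical helper, not in the original file): a quick
# sanity check of the icosahedron using Euler's formula V - E + F = 2.
def _example_check_icosahedron() -> None:
  """Hypothetical sketch: verifies vertex/edge/face counts and unit norms."""
  mesh = get_icosahedron()
  assert mesh.vertices.shape == (12, 3)
  assert mesh.faces.shape == (20, 3)
  # All vertices lie on the unit sphere.
  assert np.allclose(np.linalg.norm(mesh.vertices, axis=-1), 1.)
  # Each undirected edge appears twice in the directed edge list returned by
  # `faces_to_edges` (defined later in this module).
  senders, _ = faces_to_edges(mesh.faces)
  num_edges = len(senders) // 2
  assert mesh.vertices.shape[0] - num_edges + mesh.faces.shape[0] == 2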
def _two_split_unit_sphere_triangle_faces(
triangular_mesh: TriangularMesh) -> TriangularMesh:
"""Splits each triangular face into 4 triangles keeping the orientation."""
# Every time we split a triangle into 4 we will be adding 3 extra vertices,
# located at the edge centres.
# This class handles the positioning of the new vertices, and avoids creating
# duplicates.
new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices)
new_faces = []
for ind1, ind2, ind3 in triangular_mesh.faces:
# Transform each triangular face into 4 triangles,
# preserving the orientation.
#                       ind3
#                      /    \
#                     /      \
#                    /   #3   \
#                   /          \
#                ind31 -------- ind23
#                 /   \        /   \
#                /     \  #4  /     \
#               /  #1   \    /  #2   \
#              /         \  /         \
#           ind1 -------- ind12 -------- ind2
ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2))
ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3))
ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1))
# Note how each of the 4 triangular new faces specifies the order of the
# vertices to preserve the orientation of the original face. As the input
# face should always be counter-clockwise as specified in the diagram,
# this means child faces should also be counter-clockwise.
new_faces.extend([[ind1, ind12, ind31], # 1
[ind12, ind2, ind23], # 2
[ind31, ind23, ind3], # 3
[ind12, ind23, ind31], # 4
])
return TriangularMesh(vertices=new_vertices_builder.get_all_vertices(),
faces=np.array(new_faces, dtype=np.int32))
class _ChildVerticesBuilder(object):
"""Bookkeeping of new child vertices added to an existing set of vertices."""
def __init__(self, parent_vertices):
# Because the same new vertex will be required when splitting adjacent
# triangles (which share an edge) we keep them in a hash table indexed by
# sorted indices of the vertices adjacent to the edge, to avoid creating
# duplicated child vertices.
self._child_vertices_index_mapping = {}
self._parent_vertices = parent_vertices
# We start with all previous vertices.
self._all_vertices_list = list(parent_vertices)
def _get_child_vertex_key(self, parent_vertex_indices):
return tuple(sorted(parent_vertex_indices))
def _create_child_vertex(self, parent_vertex_indices):
"""Creates a new vertex."""
# Position for new vertex is the middle point, between the parent points,
# projected to unit sphere.
child_vertex_position = self._parent_vertices[
list(parent_vertex_indices)].mean(0)
child_vertex_position /= np.linalg.norm(child_vertex_position)
# Add the vertex to the output list. The index for this new vertex will
# match the length of the list before adding it.
child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
self._child_vertices_index_mapping[child_vertex_key] = len(
self._all_vertices_list)
self._all_vertices_list.append(child_vertex_position)
def get_new_child_vertex_index(self, parent_vertex_indices):
"""Returns index for a child vertex, creating it if necessary."""
# Get the key to see if we already have a new vertex in the middle.
child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
if child_vertex_key not in self._child_vertices_index_mapping:
self._create_child_vertex(parent_vertex_indices)
return self._child_vertices_index_mapping[child_vertex_key]
def get_all_vertices(self):
"""Returns an array with old vertices."""
return np.array(self._all_vertices_list)
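# Illustration only (hypothetical helper): adjacent faces that share an edge
# reuse the same child vertex, because the lookup key is the sorted index pair.
def _example_shared_edge_midpoint() -> None:
  """Hypothetical sketch of the child-vertex de-duplication behaviour."""
  parents = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
  builder = _ChildVerticesBuilder(parents)
  # (0, 1) and (1, 0) map to the same key, hence the same child vertex index.
  assert (builder.get_new_child_vertex_index((0, 1)) ==
          builder.get_new_child_vertex_index((1, 0)))
  # 3 parent vertices plus a single new midpoint projected onto the unit sphere.
  assert builder.get_all_vertices().shape == (4, 3)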
def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Transforms polygonal faces to sender and receiver indices.
It does so by transforming every face into N_i edges. For example, if a
triangular face has indices [0, 1, 2], three edges are added: 0->1, 1->2, and 2->0.
If all faces have consistent orientation, and the surface represented by the
faces is closed, then every edge in a polygon with a certain orientation
is also part of another polygon with the opposite orientation. In this
situation, the edges returned by the method are always bidirectional.
Args:
faces: Integer array of shape [num_faces, 3]. Contains node indices
adjacent to each face.
Returns:
Tuple with sender/receiver indices, each of shape [num_edges=num_faces*3].
"""
assert faces.ndim == 2
assert faces.shape[-1] == 3
senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]])
receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]])
return senders, receivers
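# Illustration only (hypothetical helper): the edges produced for a single face.
def _example_faces_to_edges() -> None:
  """Hypothetical sketch: face [0, 1, 2] yields directed edges 0->1, 1->2, 2->0."""
  senders, receivers = faces_to_edges(np.array([[0, 1, 2]]))
  np.testing.assert_array_equal(senders, [0, 1, 2])
  np.testing.assert_array_equal(receivers, [1, 2, 0])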
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for icosahedral_mesh."""
from absl.testing import absltest
from absl.testing import parameterized
import chex
from graphcast import icosahedral_mesh
import numpy as np
def _get_mesh_spec(splits: int):
"""Returns size of the final icosahedral mesh resulting from the splitting."""
num_vertices = 12
num_faces = 20
for _ in range(splits):
# Each previous face adds three new vertices, but each vertex is shared
# by two faces.
num_vertices += num_faces * 3 // 2
num_faces *= 4
return num_vertices, num_faces
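# Illustration only (hypothetical helper, not part of the original tests): the
# recurrence above has a closed form via Euler's formula,
# V = 2 + 10 * 4**splits and F = 20 * 4**splits.
def _example_mesh_spec_closed_form(splits: int = 3) -> None:
  """Hypothetical cross-check of _get_mesh_spec against the closed form."""
  num_vertices, num_faces = _get_mesh_spec(splits)
  assert num_faces == 20 * 4**splits
  assert num_vertices == 2 + 10 * 4**splits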
class IcosahedralMeshTest(parameterized.TestCase):
def test_icosahedron(self):
mesh = icosahedral_mesh.get_icosahedron()
_assert_valid_mesh(
mesh, num_expected_vertices=12, num_expected_faces=20)
@parameterized.parameters(list(range(5)))
def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits):
meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
splits=splits)
prev_vertices = None
for mesh_i, mesh in enumerate(meshes):
# Check that `mesh` is valid.
num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i)
_assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces)
# Check that the first N vertices from this mesh match all of the
# vertices from the previous mesh.
if prev_vertices is not None:
leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]]
np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices)
# Increase the expected/previous values for the next iteration.
if mesh_i < len(meshes) - 1:
prev_vertices = mesh.vertices
@parameterized.parameters(list(range(4)))
def test_merge_meshes(self, splits):
mesh_hierarchy = (
icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
splits=splits))
mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy)
expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0)
np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices)
np.testing.assert_array_equal(mesh.faces, expected_faces)
def test_faces_to_edges(self):
faces = np.array([[0, 1, 2],
[3, 4, 5]])
# This also documents the order of the edges returned by the method.
expected_edges = np.array(
[[0, 1],
[3, 4],
[1, 2],
[4, 5],
[2, 0],
[5, 3]])
expected_senders = expected_edges[:, 0]
expected_receivers = expected_edges[:, 1]
senders, receivers = icosahedral_mesh.faces_to_edges(faces)
np.testing.assert_array_equal(senders, expected_senders)
np.testing.assert_array_equal(receivers, expected_receivers)
def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces):
vertices = mesh.vertices
faces = mesh.faces
chex.assert_shape(vertices, [num_expected_vertices, 3])
chex.assert_shape(faces, [num_expected_faces, 3])
# Vertices norm should be 1.
vertices_norm = np.linalg.norm(vertices, axis=-1)
np.testing.assert_allclose(vertices_norm, 1., rtol=1e-6)
_assert_positive_face_orientation(vertices, faces)
def _assert_positive_face_orientation(vertices, faces):
# Obtain a unit vector normal to the face (cross product of two face edges).
face_orientation = np.cross(vertices[faces[:, 1]] - vertices[faces[:, 0]],
vertices[faces[:, 2]] - vertices[faces[:, 1]])
face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True)
# And a unit vector pointing from the origin to the center of the face.
face_centers = vertices[faces].mean(1)
face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True)
# Positive orientation means those two vectors should be parallel
# (dot product close to 1), and not anti-parallel (dot product close to -1).
dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers)
# Check that the face normal is parallel to the vector that joins the center
# of the face to the center of the sphere. Note we need a small tolerance
# because some discretizations are not exactly uniform, so it will not be
# exactly parallel.
np.testing.assert_allclose(dot_center_orientation, 1., atol=6e-4)
if __name__ == "__main__":
absltest.main()
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions (and terms for use in loss functions) used for weather."""
from typing import Mapping
from graphcast import xarray_tree
import numpy as np
from typing_extensions import Protocol
import xarray
LossAndDiagnostics = tuple[xarray.DataArray, xarray.Dataset]
class LossFunction(Protocol):
"""A loss function.
This is a protocol so it's fine to use a plain function which 'quacks like'
this. This is just to document the interface.
"""
def __call__(self,
predictions: xarray.Dataset,
targets: xarray.Dataset,
**optional_kwargs) -> LossAndDiagnostics:
"""Computes a loss function.
Args:
predictions: Dataset of predictions.
targets: Dataset of targets.
**optional_kwargs: Implementations may support extra optional kwargs.
Returns:
loss: A DataArray with dimensions ('batch',) containing losses for each
element of the batch. These will be averaged to give the final
loss, locally and across replicas.
diagnostics: Mapping of additional quantities to log by name alongside the
loss. These will typically correspond to terms in the loss. They
should also have dimensions ('batch',) and will be averaged over the
batch before logging.
"""
def weighted_mse_per_level(
predictions: xarray.Dataset,
targets: xarray.Dataset,
per_variable_weights: Mapping[str, float],
) -> LossAndDiagnostics:
"""Latitude- and pressure-level-weighted MSE loss."""
def loss(prediction, target):
loss = (prediction - target)**2
loss *= normalized_latitude_weights(target).astype(loss.dtype)
if 'level' in target.dims:
loss *= normalized_level_weights(target).astype(loss.dtype)
return _mean_preserving_batch(loss)
losses = xarray_tree.map_structure(loss, predictions, targets)
return sum_per_variable_losses(losses, per_variable_weights)
def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray:
return x.mean([d for d in x.dims if d != 'batch'], skipna=False)
def sum_per_variable_losses(
per_variable_losses: Mapping[str, xarray.DataArray],
weights: Mapping[str, float],
) -> LossAndDiagnostics:
"""Weighted sum of per-variable losses."""
if not set(weights.keys()).issubset(set(per_variable_losses.keys())):
raise ValueError(
'Passing a weight that does not correspond to any variable '
f'{set(weights.keys())-set(per_variable_losses.keys())}')
weighted_per_variable_losses = {
name: loss * weights.get(name, 1)
for name, loss in per_variable_losses.items()
}
total = xarray.concat(
weighted_per_variable_losses.values(), dim='variable', join='exact').sum(
'variable', skipna=False)
return total, per_variable_losses # pytype: disable=bad-return-type
def normalized_level_weights(data: xarray.DataArray) -> xarray.DataArray:
"""Weights proportional to pressure at each level."""
level = data.coords['level']
return level / level.mean(skipna=False)
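# Illustration only (hypothetical helper): level weights for three pressure
# levels; the mean level is 500, so the weights come out as [0.2, 1.0, 1.8].
def _example_level_weights() -> xarray.DataArray:
  """Hypothetical sketch: pressure-proportional weights with unit mean."""
  data = xarray.DataArray(
      np.zeros(3), dims=['level'], coords={'level': [100., 500., 900.]})
  return normalized_level_weights(data)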
def normalized_latitude_weights(data: xarray.DataArray) -> xarray.DataArray:
"""Weights based on latitude, roughly proportional to grid cell area.
This method supports two use cases only (both for equispaced values):
* Latitude values such that the closest value to the pole is at latitude
(90 - d_lat/2), where d_lat is the difference between contiguous latitudes.
For example: [-89, -87, -85, ..., 85, 87, 89] (d_lat = 2)
In this case each point with `lat` value represents a sphere slice between
`lat - d_lat/2` and `lat + d_lat/2`, and the area of this slice would be
proportional to:
`sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)`, and
we can simply omit the term `2 * sin(d_lat/2)` which is just a constant
that cancels during normalization.
* Latitude values that fall exactly at the poles.
For example: [-90, -88, -86, ..., 86, 88, 90] (d_lat = 2)
In this case each point with `lat` value also represents
a sphere slice between `lat - d_lat/2` and `lat + d_lat/2`,
except for the points at the poles, that represent a slice between
`90 - d_lat/2` and `90`, or `-90` and `-90 + d_lat/2`.
The areas of the first type of point are still proportional to:
* sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)
but for the points at the poles it is now:
* sin(90) - sin(90 - d_lat/2) = 2 * sin(d_lat/4) ^ 2
and we will be using these weights, depending on whether we are looking at
pole cells, or non-pole cells (omitting the common factor of 2 which will be
absorbed by the normalization).
It can be shown via a limit, or simple geometry, that in the small-angle
regime the area associated with a pole point is 1/8 of the area covered by
each of its nearest non-pole neighbours; this is verified in the unit tests.
Args:
data: `DataArray` with latitude coordinates.
Returns:
Unit mean latitude weights.
"""
latitude = data.coords['lat']
if np.any(np.isclose(np.abs(latitude), 90.)):
weights = _weight_for_latitude_vector_with_poles(latitude)
else:
weights = _weight_for_latitude_vector_without_poles(latitude)
return weights / weights.mean(skipna=False)
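# Illustration only (hypothetical helper): latitude weights for a 4-degree grid
# whose points do not include the poles; the weights are cos(lat) rescaled to
# unit mean.
def _example_latitude_weights() -> xarray.DataArray:
  """Hypothetical sketch of the no-poles code path."""
  lat = np.arange(-88., 89., 4.)  # [-88, -84, ..., 84, 88], d_lat = 4
  data = xarray.DataArray(
      np.zeros(lat.shape), dims=['lat'], coords={'lat': lat})
  return normalized_latitude_weights(data)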
def _weight_for_latitude_vector_without_poles(latitude):
"""Weights for uniform latitudes of the form [+-90-+d/2, ..., -+90+-d/2]."""
delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
if (not np.isclose(np.max(latitude), 90 - delta_latitude/2) or
not np.isclose(np.min(latitude), -90 + delta_latitude/2)):
raise ValueError(
f'Latitude vector {latitude} does not start/end at '
'+- (90 - delta_latitude/2) degrees.')
return np.cos(np.deg2rad(latitude))
def _weight_for_latitude_vector_with_poles(latitude):
"""Weights for uniform latitudes of the form [+- 90, ..., -+90]."""
delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
if (not np.isclose(np.max(latitude), 90.) or
not np.isclose(np.min(latitude), -90.)):
raise ValueError(
f'Latitude vector {latitude} does not start/end at +- 90 degrees.')
weights = np.cos(np.deg2rad(latitude)) * np.sin(np.deg2rad(delta_latitude/2))
# The two checks above are enough to guarantee that the latitudes are sorted,
# so the first and last entries are the poles.
weights[[0, -1]] = np.sin(np.deg2rad(delta_latitude/4)) ** 2
return weights
def _check_uniform_spacing_and_get_delta(vector):
diff = np.diff(vector)
if not np.all(np.isclose(diff[0], diff)):
raise ValueError(f'Vector {vector} is not uniformly spaced.')
return diff[0]
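# Illustration only (hypothetical helper): numeric check of the ~1/8
# pole-to-neighbour area ratio mentioned in the normalized_latitude_weights
# docstring.
def _example_pole_weight_ratio() -> None:
  """Hypothetical sketch: for d_lat = 1 degree the ratio is 0.125 to ~4e-5."""
  lat = np.arange(-90., 91., 1.)  # includes both poles
  weights = _weight_for_latitude_vector_with_poles(lat)
  np.testing.assert_allclose(weights[0] / weights[1], 1 / 8, rtol=1e-3)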
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrappers for Predictors which allow them to work with normalized data.
The Predictor which is wrapped sees normalized inputs and targets, and makes
normalized predictions. The wrapper handles translating the predictions back
to the original domain.
"""
import logging
from typing import Optional, Tuple
from graphcast import predictor_base
from graphcast import xarray_tree
import xarray
def normalize(values: xarray.Dataset,
scales: xarray.Dataset,
locations: Optional[xarray.Dataset],
) -> xarray.Dataset:
"""Normalize variables using the given scales and (optionally) locations."""
def normalize_array(array):
if array.name is None:
raise ValueError(
"Can't look up normalization constants because array has no name.")
if locations is not None:
if array.name in locations:
array = array - locations[array.name].astype(array.dtype)
else:
logging.warning('No normalization location found for %s', array.name)
if array.name in scales:
array = array / scales[array.name].astype(array.dtype)
else:
logging.warning('No normalization scale found for %s', array.name)
return array
return xarray_tree.map_structure(normalize_array, values)
def unnormalize(values: xarray.Dataset,
scales: xarray.Dataset,
locations: Optional[xarray.Dataset],
) -> xarray.Dataset:
"""Unnormalize variables using the given scales and (optionally) locations."""
def unnormalize_array(array):
if array.name is None:
raise ValueError(
"Can't look up normalization constants because array has no name.")
if array.name in scales:
array = array * scales[array.name].astype(array.dtype)
else:
logging.warning('No normalization scale found for %s', array.name)
if locations is not None:
if array.name in locations:
array = array + locations[array.name].astype(array.dtype)
else:
logging.warning('No normalization location found for %s', array.name)
return array
return xarray_tree.map_structure(unnormalize_array, values)
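# Illustration only (hypothetical helper): a normalize/unnormalize round trip
# on a toy dataset; the variable name and constants below are made up for the
# example.
def _example_normalization_round_trip() -> None:
  """Hypothetical sketch: unnormalize(normalize(x)) recovers x."""
  values = xarray.Dataset({'temperature': (('x',), [280., 290., 300.])})
  scales = xarray.Dataset({'temperature': ((), 10.)})
  locations = xarray.Dataset({'temperature': ((), 290.)})
  normed = normalize(values, scales, locations)      # (x - 290) / 10
  restored = unnormalize(normed, scales, locations)  # normed * 10 + 290
  assert float(abs(restored['temperature'] - values['temperature']).max()) < 1e-6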
class InputsAndResiduals(predictor_base.Predictor):
"""Wraps with a residual connection, normalizing inputs and target residuals.
The inner predictor is given inputs that are normalized using `locations`
and `scales` to roughly zero-mean unit variance.
For target variables that are present in the inputs, the inner predictor is
trained to predict residuals (target - last_frame_of_input) that have been
normalized using `residual_scales` (and optionally `residual_locations`) to
roughly unit variance / zero mean.
This replaces `residual.Predictor` in the case where you want normalization
that's based on the scales of the residuals.
Since we return the underlying predictor's loss on the normalized residuals,
if the underlying predictor's loss is a sum of per-variable losses, the normalization
will affect the relative weighting of the per-variable loss terms (hopefully
in a good way).
For target variables *not* present in the inputs, the inner predictor is
trained to predict targets directly, that have been normalized in the same
way as the inputs.
The transforms applied to the targets (the residual connection and the
normalization) are applied in reverse to the predictions before returning
them.
"""
def __init__(
self,
predictor: predictor_base.Predictor,
stddev_by_level: xarray.Dataset,
mean_by_level: xarray.Dataset,
diffs_stddev_by_level: xarray.Dataset):
self._predictor = predictor
self._scales = stddev_by_level
self._locations = mean_by_level
self._residual_scales = diffs_stddev_by_level
self._residual_locations = None
def _unnormalize_prediction_and_add_input(self, inputs, norm_prediction):
if norm_prediction.sizes.get('time') != 1:
raise ValueError(
'normalization.InputsAndResiduals only supports predicting a '
'single timestep.')
if norm_prediction.name in inputs:
# Residuals are assumed to be predicted as normalized (unit variance),
# but the scale and location they need to be mapped to are those of the
# residuals, not of the values themselves.
prediction = unnormalize(
norm_prediction, self._residual_scales, self._residual_locations)
# A prediction for which we have a corresponding input -- we are
# predicting the residual:
last_input = inputs[norm_prediction.name].isel(time=-1)
prediction = prediction + last_input
return prediction
else:
# A predicted variable which is not an input variable. We are predicting
# it directly, so unnormalize it directly to the target scale/location:
return unnormalize(norm_prediction, self._scales, self._locations)
def _subtract_input_and_normalize_target(self, inputs, target):
if target.sizes.get('time') != 1:
raise ValueError(
'normalization.InputsAndResiduals only supports wrapping predictors '
'that predict a single timestep.')
if target.name in inputs:
target_residual = target
last_input = inputs[target.name].isel(time=-1)
target_residual = target_residual - last_input
return normalize(
target_residual, self._residual_scales, self._residual_locations)
else:
return normalize(target, self._scales, self._locations)
def __call__(self,
inputs: xarray.Dataset,
targets_template: xarray.Dataset,
forcings: xarray.Dataset,
**kwargs
) -> xarray.Dataset:
norm_inputs = normalize(inputs, self._scales, self._locations)
norm_forcings = normalize(forcings, self._scales, self._locations)
norm_predictions = self._predictor(
norm_inputs, targets_template, forcings=norm_forcings, **kwargs)
return xarray_tree.map_structure(
lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
norm_predictions)
def loss(self,
inputs: xarray.Dataset,
targets: xarray.Dataset,
forcings: xarray.Dataset,
**kwargs,
) -> predictor_base.LossAndDiagnostics:
"""Returns the loss computed on normalized inputs and targets."""
norm_inputs = normalize(inputs, self._scales, self._locations)
norm_forcings = normalize(forcings, self._scales, self._locations)
norm_target_residuals = xarray_tree.map_structure(
lambda t: self._subtract_input_and_normalize_target(inputs, t),
targets)
return self._predictor.loss(
norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray
self,
inputs: xarray.Dataset,
targets: xarray.Dataset,
forcings: xarray.Dataset,
**kwargs,
) -> Tuple[predictor_base.LossAndDiagnostics,
xarray.Dataset]:
"""The loss computed on normalized data, with unnormalized predictions."""
norm_inputs = normalize(inputs, self._scales, self._locations)
norm_forcings = normalize(forcings, self._scales, self._locations)
norm_target_residuals = xarray_tree.map_structure(
lambda t: self._subtract_input_and_normalize_target(inputs, t),
targets)
(loss, scalars), norm_predictions = self._predictor.loss_and_predictions(
norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
predictions = xarray_tree.map_structure(
lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
norm_predictions)
return (loss, scalars), predictions
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data-structure for storing graphs with typed edges and nodes."""
from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar
ArrayLike = Union[Any]  # np.ndarray, jnp.ndarray, tf.Tensor
ArrayLikeTree = Union[Any, ArrayLike] # Nest of ArrayLike
_T = TypeVar('_T')
# All tensors have a "flat_batch_axis", which is similar to the leading
# axes of graph_tuples:
# * In the case of nodes this is simply a shared node and flat batch axis, with
# size corresponding to the total number of nodes in the flattened batch.
# * In the case of edges this is simply a shared edge and flat batch axis, with
# size corresponding to the total number of edges in the flattened batch.
# * In the case of globals this is simply the number of graphs in the flattened
# batch.
# All shapes may also have any additional leading shape "batch_shape".
# Options for building batches are:
# * Use a provided "flatten" method that takes a leading `batch_shape` and
# folds it into the flat_batch_axis (this will be useful when using `tf.Dataset`
# which supports batching into RaggedTensors, with leading batch shape even
# if graphs have different numbers of nodes and edges), so the RaggedBatches
# can then be converted into something without ragged dimensions that jax can
# use.
# * Directly build a "flat batch" using a provided function for batching a list
# of graphs (how it is done in `jraph`).
class NodeSet(NamedTuple):
"""Represents a set of nodes."""
n_node: ArrayLike # [num_flat_graphs]
features: ArrayLikeTree # Prev. `nodes`: [num_flat_nodes] + feature_shape
class EdgesIndices(NamedTuple):
"""Represents indices to nodes adjacent to the edges."""
senders: ArrayLike # [num_flat_edges]
receivers: ArrayLike # [num_flat_edges]
class EdgeSet(NamedTuple):
"""Represents a set of edges."""
n_edge: ArrayLike # [num_flat_graphs]
indices: EdgesIndices
features: ArrayLikeTree # Prev. `edges`: [num_flat_edges] + feature_shape
class Context(NamedTuple):
# `n_graph` always contains ones but it is useful to query the leading shape
# in case of graphs without any node or edge sets.
n_graph: ArrayLike # [num_flat_graphs]
features: ArrayLikeTree # Prev. `globals`: [num_flat_graphs] + feature_shape
class EdgeSetKey(NamedTuple):
name: str # Name of the EdgeSet.
# Sender node set name and receiver node set name connected by the edge set.
node_sets: Tuple[str, str]
class TypedGraph(NamedTuple):
"""A graph with typed nodes and edges.
A typed graph is made of a context, multiple sets of nodes and multiple
sets of edges connecting those nodes (as indicated by the EdgeSetKey).
"""
context: Context
nodes: Mapping[str, NodeSet]
edges: Mapping[EdgeSetKey, EdgeSet]
def edge_key_by_name(self, name: str) -> EdgeSetKey:
found_key = [k for k in self.edges.keys() if k.name == name]
if len(found_key) != 1:
raise KeyError("invalid edge key '{}'. Available edges: [{}]".format(
name, ', '.join(x.name for x in self.edges.keys())))
return found_key[0]
def edge_by_name(self, name: str) -> EdgeSet:
return self.edges[self.edge_key_by_name(name)]
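# Illustration only (hypothetical construction, with made-up names and sizes):
# a single flat graph with 'grid' and 'mesh' node sets and one edge set
# connecting them, using numpy arrays as the concrete ArrayLike.
def _example_typed_graph() -> TypedGraph:
  """Hypothetical sketch: builds and queries a tiny TypedGraph."""
  import numpy as np  # numpy is not imported by this module; used only here.
  graph = TypedGraph(
      context=Context(n_graph=np.array([1]), features=np.zeros((1, 1))),
      nodes={
          'grid': NodeSet(n_node=np.array([2]), features=np.zeros((2, 3))),
          'mesh': NodeSet(n_node=np.array([1]), features=np.zeros((1, 3))),
      },
      edges={
          EdgeSetKey(name='grid2mesh', node_sets=('grid', 'mesh')): EdgeSet(
              n_edge=np.array([2]),
              indices=EdgesIndices(senders=np.array([0, 1]),
                                   receivers=np.array([0, 0])),
              features=np.zeros((2, 4))),
      })
  # Edge sets can be retrieved by name via the helper methods.
  assert graph.edge_by_name('grid2mesh').n_edge[0] == 2
  return graph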