"vscode:/vscode.git/clone" did not exist on "104fcea0c8672b138a9bdd1ae00603c9240867c1"
Commit 4d04d055 authored by mashun1's avatar mashun1
Browse files

graphcast

parents
Pipeline #1048 failed with stages
in 0 seconds
*egg-info
__pycache__
# How to Contribute
We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement (CLA). You (or your employer) retain the copyright to your
contribution; this simply gives us permission to use and redistribute your
contributions as part of the project. Head over to
<https://cla.developers.google.com/> to see your current agreements on file or
to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code Reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
FROM image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04-py310
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# graphcast
## Paper
**GraphCast: Learning skillful medium-range global weather forecasting**
* https://arxiv.org/pdf/2212.12794
## Model structure
The network is a GNN with an `encode-process-decode` structure.
![alt text](asset/model_structure.png)
## Algorithm
GNN-based learned simulators are highly effective at learning and simulating the complex physical dynamics of fluids and other materials, because their representational and computational structure is analogous to a learned finite-element solver. A key advantage of GNNs is that the structure of the input graph determines which parts of the representation interact through learned message passing, allowing arbitrary ranges of spatial interaction.
![alt text](asset/alg.png)
## Environment setup
### Docker (option 1)
docker pull image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04-py310
docker run --shm-size 10g --network=host --name=graphcast --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to the project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -e .
pip uninstall shapely
pip install jupyter
pip install shapely
pip install google-cloud
pip install google-cloud-vision
pip install protobuf
pip install --upgrade google-api-python-client
pip install google.cloud.bigquery
pip install google.cloud.storage
### Dockerfile (option 2)
docker build -t <IMAGE_NAME>:<TAG> .
docker run --shm-size 10g --network=host --name=graphcast --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to the project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -e .
pip uninstall shapely
pip install jupyter
pip install shapely
pip install google-cloud
pip install google-cloud-vision
pip install protobuf
pip install --upgrade google-api-python-client
pip install google.cloud.bigquery
pip install google.cloud.storage
### Anaconda (option 3)
DTK driver: dtk24.04
Python: python3.10
jax: 0.4.23
Tip: the DCU-related tool versions above (DTK driver, Python, JAX) must match exactly as listed.
Other, non-DCU-specific libraries:
pip install -e .
pip uninstall shapely
pip install jupyter
pip install shapely
## Dataset
https://console.cloud.google.com/storage/browser/dm_graphcast
Note: download the dataset only as needed; the example data can be downloaded automatically when running `graphcast_demo.ipynb`.
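If you prefer to fetch the example inputs or pretrained weights by hand instead of through the notebook, a minimal sketch using an anonymous `google-cloud-storage` client (installed above) looks like the following; the blob name is a placeholder, so list the bucket first and pick a real file:
```python
# Minimal sketch (not part of the notebook): browse and download files from the
# public dm_graphcast bucket with an anonymous client. The blob name passed to
# download_to_filename below is a placeholder.
from google.cloud import storage

client = storage.Client.create_anonymous_client()
bucket = client.get_bucket("dm_graphcast")

# List what is available under the example-data prefix.
for blob in client.list_blobs(bucket, prefix="dataset/"):
    print(blob.name)

# Download one file locally once a name has been picked from the listing.
bucket.blob("dataset/<file-from-the-listing>.nc").download_to_filename(
    "example_batch.nc")
```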
## Training
Refer to and run the `Train Model` section of `graphcast_demo.ipynb`.
## Inference
Refer to and run the `Run the model` section of `graphcast_demo.ipynb`. A rough sketch of the final rollout step is shown below.
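The rollout at the end of the notebook looks roughly like this sketch, which assumes `run_forward_jitted`, `eval_inputs`, `eval_targets` and `eval_forcings` were already built in earlier notebook cells:
```python
# Sketch of the final autoregressive rollout, assuming the jitted forward
# function and the evaluation inputs/targets/forcings from earlier cells.
import jax
import numpy as np
from graphcast import rollout

predictions = rollout.chunked_prediction(
    run_forward_jitted,                      # one-step forward fn with params bound
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,  # only shapes/coords are used
    forcings=eval_forcings)
# `predictions` is an xarray.Dataset with one frame per target lead time.
```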
## Results
![alt text](asset/result.png)
### Accuracy
## Application scenarios
### Algorithm category
`Weather forecasting`
### Key application industries
`Meteorology, transportation, environment`
## Source repository and issue reporting
* https://developer.hpccube.com/codes/modelzoo/cfd_jax
## References
* https://github.com/google-deepmind/graphcast
# GraphCast: Learning skillful medium-range global weather forecasting
This package contains example code to run and train [GraphCast](https://arxiv.org/abs/2212.12794).
It also provides three pretrained models:
1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree
resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017,
2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree
resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from
1979 to 2015, useful to run a model with lower memory and compute constraints,
3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13
pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on
HRES data from 2016 to 2021. This model can be initialized from HRES data (does
not require precipitation inputs).
The model weights, normalization statistics, and example inputs are available on [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast).
Full model training requires downloading the
[ERA5](https://www.ecmwf.int/en/forecasts/datasets/reanalysis-datasets/era5)
dataset, available from [ECMWF](https://www.ecmwf.int/). This can best be
accessed as Zarr from [Weatherbench2's ERA5 data](https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5) (see the 6h downsampled versions).
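For orientation, opening one of those Zarr stores lazily with `xarray` (plus `gcsfs` for the `gs://` URL) looks roughly as follows; the store path is only an example and should be checked against the WeatherBench2 data guide:
```python
# Sketch: lazily open a WeatherBench2 ERA5 Zarr store. Requires gcsfs for the
# gs:// URL; the exact store path is an example and may change over time.
import xarray as xr

era5 = xr.open_zarr(
    "gs://weatherbench2/datasets/era5/"
    "1959-2022-6h-64x32_equiangular_with_poles_conservative.zarr")
print(era5)  # metadata only; chunks are read on demand
```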
## Overview of files
The best starting point is to open `graphcast_demo.ipynb` in [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/graphcast_demo.ipynb), which gives an
example of loading data, generating random weights or loading a pre-trained
snapshot, generating predictions, and computing the loss and its gradients.
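Loading one of the pretrained snapshots outside the Colab follows the same pattern as the notebook; a minimal sketch, where the local filename is a placeholder for one of the files under `params/` in the bucket:
```python
# Sketch of restoring a pretrained snapshot with checkpoint.load. The filename
# is a placeholder for a params/*.npz file downloaded from the bucket.
from graphcast import checkpoint, graphcast

with open("GraphCast_small.npz", "rb") as f:  # placeholder path
  ckpt = checkpoint.load(f, graphcast.CheckPoint)

params = ckpt.params              # Haiku parameter tree
model_config = ckpt.model_config  # mesh size, resolution, latent sizes
task_config = ckpt.task_config    # input/target/forcing variables, levels
print(ckpt.description)
```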
The one-step implementation of the GraphCast architecture is provided in
`graphcast.py`.
### Brief description of library files:
* `autoregressive.py`: Wrapper used to run (and train) the one-step GraphCast
to produce a sequence of predictions by auto-regressively feeding the
outputs back as inputs at each step, in a JAX-differentiable way (see the
composition sketch after this list).
* `casting.py`: Wrapper used around GraphCast to make it work using
BFloat16 precision.
* `checkpoint.py`: Utils to serialize and deserialize trees.
* `data_utils.py`: Utils for data preprocessing.
* `deep_typed_graph_net.py`: General purpose deep graph neural network (GNN)
that operates on `TypedGraph`'s where both inputs and outputs are flat
vectors of features for each of the nodes and edges. `graphcast.py` uses
three of these for the Grid2Mesh GNN, the Multi-mesh GNN and the Mesh2Grid
GNN, respectively.
* `graphcast.py`: The main GraphCast model architecture for one-step of
predictions.
* `grid_mesh_connectivity.py`: Tools for converting between regular grids on a
sphere and triangular meshes.
* `icosahedral_mesh.py`: Definition of an icosahedral multi-mesh.
* `losses.py`: Loss computations, including latitude-weighting.
* `model_utils.py`: Utilities to produce flat node and edge vector features
from input grid data, and to manipulate the node output vectors back
into multilevel grid data.
* `normalization.py`: Wrapper for the one-step GraphCast used to normalize
inputs according to historical values, and targets according to historical
time differences.
* `predictor_base.py`: Defines the interface of the predictor, which GraphCast
and all of the wrappers implement.
* `rollout.py`: Similar to `autoregressive.py` but used only at inference time
using a python loop to produce longer, but non-differentiable trajectories.
* `solar_radiation.py`: Computes Top-Of-the-Atmosphere (TOA) incident solar
radiation compatible with ERA5. This is used as a forcing variable and thus
needs to be computed for target lead times in an operational setting.
* `typed_graph.py`: Definition of `TypedGraph`'s.
* `typed_graph_net.py`: Implementation of simple graph neural network
building blocks defined over `TypedGraph`'s that can be combined to build
deeper models.
* `xarray_jax.py`: A wrapper to let JAX work with `xarray`s.
* `xarray_tree.py`: An implementation of tree.map_structure that works with
`xarray`s.
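The wrappers above compose around the one-step model roughly as in the notebook's model construction. A sketch, assuming the config objects and the normalization statistics (the xarray datasets in the `stats/` folder of the bucket) have been loaded:
```python
# Sketch of how the wrapper modules compose; in the notebook this runs inside a
# Haiku-transformed function. Configs and normalization statistics are assumed
# to have been loaded already (e.g. from the checkpoint and the stats/ files).
from graphcast import autoregressive, casting, graphcast, normalization

def construct_wrapped_graphcast(model_config, task_config,
                                diffs_stddev_by_level, mean_by_level,
                                stddev_by_level):
  predictor = graphcast.GraphCast(model_config, task_config)  # one-step GNN
  predictor = casting.Bfloat16Cast(predictor)        # bfloat16 activations
  predictor = normalization.InputsAndResiduals(      # normalize inputs/targets
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)
  predictor = autoregressive.Predictor(              # multi-step rollout
      predictor, gradient_checkpointing=True)
  return predictor
```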
### Dependencies.
[Chex](https://github.com/deepmind/chex),
[Dask](https://github.com/dask/dask),
[Haiku](https://github.com/deepmind/dm-haiku),
[JAX](https://github.com/google/jax),
[JAXline](https://github.com/deepmind/jaxline),
[Jraph](https://github.com/deepmind/jraph),
[Numpy](https://numpy.org/),
[Pandas](https://pandas.pydata.org/),
[Python](https://www.python.org/),
[SciPy](https://scipy.org/),
[Tree](https://github.com/deepmind/tree),
[Trimesh](https://github.com/mikedh/trimesh) and
[XArray](https://github.com/pydata/xarray).
### License and attribution
The Colab notebook and the associated code are licensed under the Apache
License, Version 2.0. You may obtain a copy of the License at:
https://www.apache.org/licenses/LICENSE-2.0.
The model weights are made available for use under the terms of the Creative
Commons Attribution-NonCommercial-ShareAlike 4.0 International
(CC BY-NC-SA 4.0). You may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ECMWF's ERA5 and HRES data. The colab includes a few
examples of ERA5 and HRES data that can be used as inputs to the models.
ECMWF data products are subject to the following terms:
1. Copyright statement: Copyright "© 2023 European Centre for Medium-Range Weather Forecasts (ECMWF)".
2. Source www.ecmwf.int
3. Licence Statement: ECMWF data is published under a Creative Commons Attribution 4.0 International (CC BY 4.0). https://creativecommons.org/licenses/by/4.0/
4. Disclaimer: ECMWF does not accept any liability whatsoever for any error or omission in the data, their availability, or for any loss or damage arising from their use.
### Disclaimer
This is not an officially supported Google product.
Copyright 2023 DeepMind Technologies Limited.
### Citation
If you use this work, consider citing our [paper](https://arxiv.org/abs/2212.12794):
```latex
@article{lam2022graphcast,
title={GraphCast: Learning skillful medium-range global weather forecasting},
author={Remi Lam and Alvaro Sanchez-Gonzalez and Matthew Willson and Peter Wirnsberger and Meire Fortunato and Alexander Pritzel and Suman Ravuri and Timo Ewalds and Ferran Alet and Zach Eaton-Rosen and Weihua Hu and Alexander Merose and Stephan Hoyer and George Holland and Jacklynn Stott and Oriol Vinyals and Shakir Mohamed and Peter Battaglia},
year={2022},
eprint={2212.12794},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Predictor wrapping a one-step Predictor to make autoregressive predictions.
"""
from typing import Optional, cast
from absl import logging
from graphcast import predictor_base
from graphcast import xarray_jax
from graphcast import xarray_tree
import haiku as hk
import jax
import xarray
def _unflatten_and_expand_time(flat_variables, tree_def, time_coords):
variables = jax.tree_util.tree_unflatten(tree_def, flat_variables)
return variables.expand_dims(time=time_coords, axis=0)
def _get_flat_arrays_and_single_timestep_treedef(variables):
flat_arrays = jax.tree_util.tree_leaves(variables.transpose('time', ...))
_, treedef = jax.tree_util.tree_flatten(variables.isel(time=0, drop=True))
return flat_arrays, treedef
class Predictor(predictor_base.Predictor):
"""Wraps a one-step Predictor to make multi-step predictions autoregressively.
The wrapped Predictor will be used to predict a single timestep conditional
on the inputs passed to the outer Predictor. Its predictions are then
passed back in as inputs at the next timestep, for as many timesteps as are
requested in the targets_template. (When multiple timesteps of input are
used, a rolling window of inputs is maintained with new predictions
concatenated onto the end).
You may ask for additional variables to be predicted as targets which aren't
used as inputs. These will be predicted as output variables only and not fed
back in autoregressively. All target variables must be time-dependent however.
You may also specify static (non-time-dependent) inputs which will be passed
in at each timestep but are not predicted.
At present, any time-dependent inputs must also be present as targets so they
can be passed in autoregressively.
The loss of the wrapped one-step Predictor is averaged over all timesteps to
give a loss for the autoregressive Predictor.
"""
def __init__(
self,
predictor: predictor_base.Predictor,
noise_level: Optional[float] = None,
gradient_checkpointing: bool = False,
):
"""Initializes an autoregressive predictor wrapper.
Args:
predictor: A predictor to wrap in an auto-regressive way.
noise_level: Optional value that multiplies the standard normal noise
added to the time-dependent variables of the predictor inputs. In
particular, no noise is added to the predictions that are fed back
auto-regressively. Defaults to not adding noise.
gradient_checkpointing: If True, gradient checkpointing will be
        used at each step of the computation to save on memory. Roughly, this
        should make the backwards pass two times more expensive, and the time
        per step, counting the forward pass, should only increase by about 50%.
Note this parameter will be ignored with a warning if the scan sequence
length is 1.
"""
self._predictor = predictor
self._noise_level = noise_level
self._gradient_checkpointing = gradient_checkpointing
def _get_and_validate_constant_inputs(self, inputs, targets, forcings):
constant_inputs = inputs.drop_vars(targets.keys(), errors='ignore')
constant_inputs = constant_inputs.drop_vars(
forcings.keys(), errors='ignore')
for name, var in constant_inputs.items():
if 'time' in var.dims:
raise ValueError(
f'Time-dependent input variable {name} must either be a forcing '
'variable, or a target variable to allow for auto-regressive '
'feedback.')
return constant_inputs
def _validate_targets_and_forcings(self, targets, forcings):
for name, var in targets.items():
if 'time' not in var.dims:
raise ValueError(f'Target variable {name} must be time-dependent.')
for name, var in forcings.items():
if 'time' not in var.dims:
raise ValueError(f'Forcing variable {name} must be time-dependent.')
overlap = forcings.keys() & targets.keys()
if overlap:
raise ValueError('The following were specified as both targets and '
f'forcings, which isn\'t allowed: {overlap}')
def _update_inputs(self, inputs, next_frame):
num_inputs = inputs.dims['time']
predicted_or_forced_inputs = next_frame[list(inputs.keys())]
# Combining datasets with inputs and target time stamps aligns them.
# Only keep the num_inputs trailing frames for use as next inputs.
return (xarray.concat([inputs, predicted_or_forced_inputs], dim='time')
.tail(time=num_inputs)
# Update the time coordinate to reset the lead times for
# next AR iteration.
.assign_coords(time=inputs.coords['time']))
def __call__(self,
inputs: xarray.Dataset,
targets_template: xarray.Dataset,
forcings: xarray.Dataset,
**kwargs) -> xarray.Dataset:
"""Calls the Predictor.
Args:
      inputs: input variables used to make predictions. Inputs can include both
        time-dependent and time-independent variables. Any time-dependent
        input variables must also be present in the targets_template or the
        forcings.
      targets_template: A target template containing information about which
        variables should be predicted and the time alignment of the
        predictions. All target variables must be time-dependent. The number
        of time frames is used to set the number of unroll steps of the AR
        predictor (e.g. multiple unroll steps of the inner predictor for one
        time step in the targets are not supported yet).
      forcings: Variables that will be fed to the model. These variables should
        not overlap with the target ones. The time coordinates of the forcing
        variables should match the target ones.
        Forcing variables which are also present in the inputs will be used to
        supply ground-truth values for those inputs when they are passed to the
        underlying predictor at timesteps beyond the first timestep.
**kwargs: Additional arguments passed along to the inner Predictor.
Returns:
predictions: the model predictions matching the target template.
    Raises:
      ValueError: if the time coordinates of the inputs and targets do not
        differ by a constant time step.
"""
constant_inputs = self._get_and_validate_constant_inputs(
inputs, targets_template, forcings)
self._validate_targets_and_forcings(targets_template, forcings)
# After the above checks, the remaining inputs must be time-dependent:
inputs = inputs.drop_vars(constant_inputs.keys())
# A predictions template only including the next time to predict.
target_template = targets_template.isel(time=[0])
flat_forcings, forcings_treedef = (
_get_flat_arrays_and_single_timestep_treedef(forcings))
scan_variables = flat_forcings
def one_step_prediction(inputs, scan_variables):
flat_forcings = scan_variables
forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef,
target_template.coords['time'])
# Add constant inputs:
all_inputs = xarray.merge([constant_inputs, inputs])
predictions: xarray.Dataset = self._predictor(
all_inputs, target_template,
forcings=forcings,
**kwargs)
next_frame = xarray.merge([predictions, forcings])
next_inputs = self._update_inputs(inputs, next_frame)
# Drop the length-1 time dimension, since scan will concat all the outputs
# for different times along a new leading time dimension:
predictions = predictions.squeeze('time', drop=True)
# We return the prediction flattened into plain jax arrays, because the
# extra leading dimension added by scan prevents the tree_util
# registrations in xarray_jax from unflattening them back into an
# xarray.Dataset automatically:
flat_pred = jax.tree_util.tree_leaves(predictions)
return next_inputs, flat_pred
if self._gradient_checkpointing:
scan_length = targets_template.dims['time']
if scan_length <= 1:
logging.warning(
'Skipping gradient checkpointing for sequence length of 1')
else:
# Just in case we take gradients (e.g. for control), although
# in most cases this will just be for a forward pass.
one_step_prediction = hk.remat(one_step_prediction)
# Loop (without unroll) with hk states in cell (jax.lax.scan won't do).
_, flat_preds = hk.scan(one_step_prediction, inputs, scan_variables)
# The result of scan will have an extra leading axis on all arrays,
# corresponding to the target times in this case. We need to be prepared for
# it when unflattening the arrays back into a Dataset:
scan_result_template = (
target_template.squeeze('time', drop=True)
.expand_dims(time=targets_template.coords['time'], axis=0))
_, scan_result_treedef = jax.tree_util.tree_flatten(scan_result_template)
predictions = jax.tree_util.tree_unflatten(scan_result_treedef, flat_preds)
return predictions
def loss(self,
inputs: xarray.Dataset,
targets: xarray.Dataset,
forcings: xarray.Dataset,
**kwargs
) -> predictor_base.LossAndDiagnostics:
"""The mean of the per-timestep losses of the underlying predictor."""
if targets.sizes['time'] == 1:
# If there is only a single target timestep then we don't need any
# autoregressive feedback and can delegate the loss directly to the
# underlying single-step predictor. This means the underlying predictor
# doesn't need to implement .loss_and_predictions.
return self._predictor.loss(inputs, targets, forcings, **kwargs)
constant_inputs = self._get_and_validate_constant_inputs(
inputs, targets, forcings)
self._validate_targets_and_forcings(targets, forcings)
# After the above checks, the remaining inputs must be time-dependent:
inputs = inputs.drop_vars(constant_inputs.keys())
if self._noise_level:
def add_noise(x):
return x + self._noise_level * jax.random.normal(
hk.next_rng_key(), shape=x.shape)
# Add noise to time-dependent variables of the inputs.
inputs = jax.tree_map(add_noise, inputs)
# The per-timestep targets passed by scan to one_step_loss below will have
# no leading time axis. We need a treedef without the time axis to use
# inside one_step_loss to unflatten it back into a dataset:
flat_targets, target_treedef = _get_flat_arrays_and_single_timestep_treedef(
targets)
scan_variables = flat_targets
flat_forcings, forcings_treedef = (
_get_flat_arrays_and_single_timestep_treedef(forcings))
scan_variables = (flat_targets, flat_forcings)
def one_step_loss(inputs, scan_variables):
flat_target, flat_forcings = scan_variables
forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef,
targets.coords['time'][:1])
target = _unflatten_and_expand_time(flat_target, target_treedef,
targets.coords['time'][:1])
# Add constant inputs:
all_inputs = xarray.merge([constant_inputs, inputs])
(loss, diagnostics), predictions = self._predictor.loss_and_predictions(
all_inputs,
target,
forcings=forcings,
**kwargs)
# Unwrap to jax arrays shape (batch,):
loss, diagnostics = xarray_tree.map_structure(
xarray_jax.unwrap_data, (loss, diagnostics))
predictions = cast(xarray.Dataset, predictions) # Keeps pytype happy.
next_frame = xarray.merge([predictions, forcings])
next_inputs = self._update_inputs(inputs, next_frame)
return next_inputs, (loss, diagnostics)
if self._gradient_checkpointing:
scan_length = targets.dims['time']
if scan_length <= 1:
logging.warning(
'Skipping gradient checkpointing for sequence length of 1')
else:
one_step_loss = hk.remat(one_step_loss)
# We can pass inputs (the initial state of the loop) in directly as a
# Dataset because the shape we pass in to scan is the same as the shape scan
# passes to the inner function. But, for scan_variables, we must flatten the
# targets (and unflatten them inside the inner function) because they are
# passed to the inner function per-timestep without the original time axis.
    # The same applies to the optional forcings.
_, (per_timestep_losses, per_timestep_diagnostics) = hk.scan(
one_step_loss, inputs, scan_variables)
# Re-wrap loss and diagnostics as DataArray and average them over time:
(loss, diagnostics) = jax.tree_util.tree_map(
lambda x: xarray_jax.DataArray(x, dims=('time', 'batch')).mean( # pylint: disable=g-long-lambda
'time', skipna=False),
(per_timestep_losses, per_timestep_diagnostics))
return loss, diagnostics
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrappers that take care of casting."""
import contextlib
from typing import Any, Mapping, Tuple
import chex
from graphcast import predictor_base
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import xarray
PyTree = Any
class Bfloat16Cast(predictor_base.Predictor):
"""Wrapper that casts all inputs to bfloat16 and outputs to targets dtype."""
def __init__(self, predictor: predictor_base.Predictor, enabled: bool = True):
"""Inits the wrapper.
Args:
predictor: predictor being wrapped.
enabled: disables the wrapper if False, for simpler hyperparameter scans.
"""
self._enabled = enabled
self._predictor = predictor
def __call__(self,
inputs: xarray.Dataset,
targets_template: xarray.Dataset,
forcings: xarray.Dataset,
**kwargs
) -> xarray.Dataset:
if not self._enabled:
return self._predictor(inputs, targets_template, forcings, **kwargs)
with bfloat16_variable_view():
predictions = self._predictor(
*_all_inputs_to_bfloat16(inputs, targets_template, forcings),
**kwargs,)
predictions_dtype = infer_floating_dtype(predictions) # pytype: disable=wrong-arg-types
if predictions_dtype != jnp.bfloat16:
raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')
targets_dtype = infer_floating_dtype(targets_template) # pytype: disable=wrong-arg-types
return tree_map_cast(
predictions, input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
def loss(self,
inputs: xarray.Dataset,
targets: xarray.Dataset,
forcings: xarray.Dataset,
**kwargs,
) -> predictor_base.LossAndDiagnostics:
if not self._enabled:
return self._predictor.loss(inputs, targets, forcings, **kwargs)
with bfloat16_variable_view():
loss, scalars = self._predictor.loss(
*_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)
if loss.dtype != jnp.bfloat16:
raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')
targets_dtype = infer_floating_dtype(targets) # pytype: disable=wrong-arg-types
# Note that casting back the loss to e.g. float32 should not affect data
# types of the backwards pass, because the first thing the backwards pass
    # should do is to go back through the casting op and cast back to bfloat16
# (and xprofs seem to confirm this).
return tree_map_cast((loss, scalars),
input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray
self,
inputs: xarray.Dataset,
targets: xarray.Dataset,
forcings: xarray.Dataset,
**kwargs,
) -> Tuple[predictor_base.LossAndDiagnostics,
xarray.Dataset]:
if not self._enabled:
return self._predictor.loss_and_predictions(inputs, targets, forcings, # pytype: disable=bad-return-type # jax-ndarray
**kwargs)
with bfloat16_variable_view():
(loss, scalars), predictions = self._predictor.loss_and_predictions(
*_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)
if loss.dtype != jnp.bfloat16:
raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')
predictions_dtype = infer_floating_dtype(predictions) # pytype: disable=wrong-arg-types
if predictions_dtype != jnp.bfloat16:
raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')
targets_dtype = infer_floating_dtype(targets) # pytype: disable=wrong-arg-types
return tree_map_cast(((loss, scalars), predictions),
input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
def infer_floating_dtype(data_vars: Mapping[str, chex.Array]) -> np.dtype:
"""Infers a floating dtype from an input mapping of data."""
dtypes = {
v.dtype
for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
if len(dtypes) != 1:
dtypes_and_shapes = {
k: (v.dtype, v.shape)
for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
raise ValueError(
        f'Did not find exactly one floating dtype {dtypes} in input variables: '
        f'{dtypes_and_shapes}')
return list(dtypes)[0]
def _all_inputs_to_bfloat16(
inputs: xarray.Dataset,
targets: xarray.Dataset,
forcings: xarray.Dataset,
) -> Tuple[xarray.Dataset,
xarray.Dataset,
xarray.Dataset]:
return (inputs.astype(jnp.bfloat16),
jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets),
forcings.astype(jnp.bfloat16))
def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype,
) -> PyTree:
  def cast_fn(x):
    if x.dtype == input_dtype:
      return x.astype(output_dtype)
    return x  # Leave leaves of other dtypes unchanged.
  return jax.tree_map(cast_fn, inputs)
@contextlib.contextmanager
def bfloat16_variable_view(enabled: bool = True):
"""Context for Haiku modules with float32 params, but bfloat16 activations.
It works as follows:
* Every time a variable is requested to be created/set as np.bfloat16,
it will create an underlying float32 variable, instead.
  * Every time a variable is requested as bfloat16, it will check that the
    variable is of float32 type, and cast the variable to bfloat16.
Note the gradients are still computed and accumulated as float32, because
the params returned by init are float32, so the gradient function with
respect to the params will already include an implicit casting to float32.
Args:
enabled: Only enables bfloat16 behavior if True.
Yields:
None
"""
if enabled:
with hk.custom_creator(
_bfloat16_creator, state=True), hk.custom_getter(
_bfloat16_getter, state=True), hk.custom_setter(
_bfloat16_setter):
yield
else:
yield
def _bfloat16_creator(next_creator, shape, dtype, init, context):
"""Creates float32 variables when bfloat16 is requested."""
if context.original_dtype == jnp.bfloat16:
dtype = jnp.float32
return next_creator(shape, dtype, init)
def _bfloat16_getter(next_getter, value, context):
"""Casts float32 to bfloat16 when bfloat16 was originally requested."""
if context.original_dtype == jnp.bfloat16:
assert value.dtype == jnp.float32
value = value.astype(jnp.bfloat16)
return next_getter(value)
def _bfloat16_setter(next_setter, value, context):
"""Casts bfloat16 to float32 when bfloat16 was originally set."""
if context.original_dtype == jnp.bfloat16:
value = value.astype(jnp.float32)
return next_setter(value)
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serialize and deserialize trees."""
import dataclasses
import io
import types
from typing import Any, BinaryIO, Optional, TypeVar
import numpy as np
_T = TypeVar("_T")
def dump(dest: BinaryIO, value: Any) -> None:
"""Dump a tree of dicts/dataclasses to a file object.
Args:
dest: a file object to write to.
value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and
other basic types. Unions are not supported, other than Optional/None
which is only supported in dataclasses, not in dicts, lists or tuples.
All leaves must be coercible to a numpy array, and recoverable as a single
arg to a type.
"""
buffer = io.BytesIO() # In case the destination doesn't support seeking.
np.savez(buffer, **_flatten(value))
dest.write(buffer.getvalue())
def load(source: BinaryIO, typ: type[_T]) -> _T:
"""Load from a file object and convert it to the specified type.
Args:
source: a file object to read from.
typ: a type object that acts as a schema for deserialization. It must match
what was serialized. If a type is Any, it will be returned however numpy
serialized it, which is what you want for a tree of numpy arrays.
Returns:
the deserialized value as the specified type.
"""
return _convert_types(typ, _unflatten(np.load(source)))
_SEP = ":"
def _flatten(tree: Any) -> dict[str, Any]:
"""Flatten a tree of dicts/dataclasses/lists/tuples to a single dict."""
if dataclasses.is_dataclass(tree):
# Don't use dataclasses.asdict as it is recursive so skips dropping None.
tree = {f.name: v for f in dataclasses.fields(tree)
if (v := getattr(tree, f.name)) is not None}
elif isinstance(tree, (list, tuple)):
tree = dict(enumerate(tree))
assert isinstance(tree, dict)
flat = {}
for k, v in tree.items():
k = str(k)
assert _SEP not in k
if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)):
for a, b in _flatten(v).items():
flat[f"{k}{_SEP}{a}"] = b
else:
assert v is not None
flat[k] = v
return flat
def _unflatten(flat: dict[str, Any]) -> dict[str, Any]:
"""Unflatten a dict to a tree of dicts."""
tree = {}
for flat_key, v in flat.items():
node = tree
keys = flat_key.split(_SEP)
for k in keys[:-1]:
if k not in node:
node[k] = {}
node = node[k]
node[keys[-1]] = v
return tree
def _convert_types(typ: type[_T], value: Any) -> _T:
"""Convert some structure into the given type. The structures must match."""
if typ in (Any, ...):
return value
if typ in (int, float, str, bool):
return typ(value)
if typ is np.ndarray:
assert isinstance(value, np.ndarray)
return value
if dataclasses.is_dataclass(typ):
kwargs = {}
for f in dataclasses.fields(typ):
# Only support Optional for dataclasses, as numpy can't serialize it
# directly (without pickle), and dataclasses are the only case where we
# can know the full set of values and types and therefore know the
# non-existence must mean None.
if isinstance(f.type, (types.UnionType, type(Optional[int]))):
constructors = [t for t in f.type.__args__ if t is not types.NoneType]
if len(constructors) != 1:
raise TypeError(
"Optional works, Union with anything except None doesn't")
if f.name not in value:
kwargs[f.name] = None
continue
constructor = constructors[0]
else:
constructor = f.type
if f.name in value:
kwargs[f.name] = _convert_types(constructor, value[f.name])
else:
raise ValueError(f"Missing value: {f.name}")
return typ(**kwargs)
base_type = getattr(typ, "__origin__", None)
if base_type is dict:
assert len(typ.__args__) == 2
key_type, value_type = typ.__args__
return {_convert_types(key_type, k): _convert_types(value_type, v)
for k, v in value.items()}
if base_type is list:
assert len(typ.__args__) == 1
value_type = typ.__args__[0]
return [_convert_types(value_type, v)
for _, v in sorted(value.items(), key=lambda x: int(x[0]))]
if base_type is tuple:
if len(typ.__args__) == 2 and typ.__args__[1] == ...:
# An arbitrary length tuple of a single type, eg: tuple[int, ...]
value_type = typ.__args__[0]
return tuple(_convert_types(value_type, v)
for _, v in sorted(value.items(), key=lambda x: int(x[0])))
else:
# A fixed length tuple of arbitrary types, eg: tuple[int, str, float]
assert len(typ.__args__) == len(value)
return tuple(
_convert_types(t, v)
for t, (_, v) in zip(
typ.__args__, sorted(value.items(), key=lambda x: int(x[0]))))
# This is probably unreachable with reasonable serializable inputs.
try:
return typ(value)
except TypeError as e:
raise TypeError(
"_convert_types expects the type argument to be a dataclass defined "
"with types that are valid constructors (eg tuple is fine, Tuple "
"isn't), and accept a numpy array as the sole argument.") from e
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Check that the checkpoint serialization is reversable."""
import dataclasses
import io
from typing import Any, Optional, Union
from absl.testing import absltest
from graphcast import checkpoint
import numpy as np
@dataclasses.dataclass
class SubConfig:
a: int
b: str
@dataclasses.dataclass
class Config:
bt: bool
bf: bool
i: int
f: float
o1: Optional[int]
o2: Optional[int]
o3: Union[int, None]
o4: Union[int, None]
o5: int | None
o6: int | None
li: list[int]
ls: list[str]
ldc: list[SubConfig]
tf: tuple[float, ...]
ts: tuple[str, ...]
t: tuple[str, int, SubConfig]
tdc: tuple[SubConfig, ...]
dsi: dict[str, int]
dss: dict[str, str]
dis: dict[int, str]
dsdis: dict[str, dict[int, str]]
dc: SubConfig
dco: Optional[SubConfig]
ddc: dict[str, SubConfig]
@dataclasses.dataclass
class Checkpoint:
params: dict[str, Any]
config: Config
class DataclassTest(absltest.TestCase):
def test_serialize_dataclass(self):
ckpt = Checkpoint(
params={
"layer1": {
"w": np.arange(10).reshape(2, 5),
"b": np.array([2, 6]),
},
"layer2": {
"w": np.arange(8).reshape(2, 4),
"b": np.array([2, 6]),
},
"blah": np.array([3, 9]),
},
config=Config(
bt=True,
bf=False,
i=42,
f=3.14,
o1=1,
o2=None,
o3=2,
o4=None,
o5=3,
o6=None,
li=[12, 9, 7, 15, 16, 14, 1, 6, 11, 4, 10, 5, 13, 3, 8, 2],
ls=list("qhjfdxtpzgemryoikwvblcaus"),
ldc=[SubConfig(1, "hello"), SubConfig(2, "world")],
tf=(1, 4, 2, 10, 5, 9, 13, 16, 15, 8, 12, 7, 11, 14, 3, 6),
ts=("hello", "world"),
t=("foo", 42, SubConfig(1, "bar")),
tdc=(SubConfig(1, "hello"), SubConfig(2, "world")),
dsi={"a": 1, "b": 2, "c": 3},
dss={"d": "e", "f": "g"},
dis={1: "a", 2: "b", 3: "c"},
dsdis={"a": {1: "hello", 2: "world"}, "b": {1: "world"}},
dc=SubConfig(1, "hello"),
dco=None,
ddc={"a": SubConfig(1, "hello"), "b": SubConfig(2, "world")},
))
buffer = io.BytesIO()
checkpoint.dump(buffer, ckpt)
buffer.seek(0)
ckpt2 = checkpoint.load(buffer, Checkpoint)
np.testing.assert_array_equal(ckpt.params["layer1"]["w"],
ckpt2.params["layer1"]["w"])
np.testing.assert_array_equal(ckpt.params["layer1"]["b"],
ckpt2.params["layer1"]["b"])
np.testing.assert_array_equal(ckpt.params["layer2"]["w"],
ckpt2.params["layer2"]["w"])
np.testing.assert_array_equal(ckpt.params["layer2"]["b"],
ckpt2.params["layer2"]["b"])
np.testing.assert_array_equal(ckpt.params["blah"], ckpt2.params["blah"])
self.assertEqual(ckpt.config, ckpt2.config)
if __name__ == "__main__":
absltest.main()
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset utilities."""
from typing import Any, Mapping, Sequence, Tuple, Union
from graphcast import solar_radiation
import numpy as np
import pandas as pd
import xarray
TimedeltaLike = Any # Something convertible to pd.Timedelta.
TimedeltaStr = str # A string convertible to pd.Timedelta.
TargetLeadTimes = Union[
TimedeltaLike,
Sequence[TimedeltaLike],
slice # with TimedeltaLike as its start and stop.
]
_SEC_PER_HOUR = 3600
_HOUR_PER_DAY = 24
SEC_PER_DAY = _SEC_PER_HOUR * _HOUR_PER_DAY
_AVG_DAY_PER_YEAR = 365.24219
AVG_SEC_PER_YEAR = SEC_PER_DAY * _AVG_DAY_PER_YEAR
DAY_PROGRESS = "day_progress"
YEAR_PROGRESS = "year_progress"
_DERIVED_VARS = {
DAY_PROGRESS,
f"{DAY_PROGRESS}_sin",
f"{DAY_PROGRESS}_cos",
YEAR_PROGRESS,
f"{YEAR_PROGRESS}_sin",
f"{YEAR_PROGRESS}_cos",
}
TISR = "toa_incident_solar_radiation"
def get_year_progress(seconds_since_epoch: np.ndarray) -> np.ndarray:
"""Computes year progress for times in seconds.
Args:
seconds_since_epoch: Times in seconds since the "epoch" (the point at which
UNIX time starts).
Returns:
Year progress normalized to be in the [0, 1) interval for each time point.
"""
# Start with the pure integer division, and then float at the very end.
# We will try to keep as much precision as possible.
years_since_epoch = (
seconds_since_epoch / SEC_PER_DAY / np.float64(_AVG_DAY_PER_YEAR)
)
  # Note depending on how these ops are done, we may end up with a "weak_type"
  # which can cause issues in subtle ways that are hard to track down here.
# In any case, casting to float32 should get rid of the weak type.
# [0, 1.) Interval.
return np.mod(years_since_epoch, 1.0).astype(np.float32)
def get_day_progress(
seconds_since_epoch: np.ndarray,
longitude: np.ndarray,
) -> np.ndarray:
"""Computes day progress for times in seconds at each longitude.
Args:
seconds_since_epoch: 1D array of times in seconds since the 'epoch' (the
point at which UNIX time starts).
longitude: 1D array of longitudes at which day progress is computed.
Returns:
    2D array of day progress values normalized to be in the [0, 1) interval
for each time point at each longitude.
"""
# [0.0, 1.0) Interval.
day_progress_greenwich = (
np.mod(seconds_since_epoch, SEC_PER_DAY) / SEC_PER_DAY
)
# Offset the day progress to the longitude of each point on Earth.
longitude_offsets = np.deg2rad(longitude) / (2 * np.pi)
day_progress = np.mod(
day_progress_greenwich[..., np.newaxis] + longitude_offsets, 1.0
)
return day_progress.astype(np.float32)
def featurize_progress(
name: str, dims: Sequence[str], progress: np.ndarray
) -> Mapping[str, xarray.Variable]:
"""Derives features used by ML models from the `progress` variable.
Args:
name: Base variable name from which features are derived.
dims: List of the output feature dimensions, e.g. ("day", "lon").
progress: Progress variable values.
Returns:
Dictionary of xarray variables derived from the `progress` values. It
includes the original `progress` variable along with its sin and cos
transformations.
Raises:
ValueError if the number of feature dimensions is not equal to the number
of data dimensions.
"""
if len(dims) != progress.ndim:
raise ValueError(
f"Number of feature dimensions ({len(dims)}) must be equal to the"
f" number of data dimensions: {progress.ndim}."
)
progress_phase = progress * (2 * np.pi)
return {
name: xarray.Variable(dims, progress),
name + "_sin": xarray.Variable(dims, np.sin(progress_phase)),
name + "_cos": xarray.Variable(dims, np.cos(progress_phase)),
}
def add_derived_vars(data: xarray.Dataset) -> None:
"""Adds year and day progress features to `data` in place if missing.
Args:
data: Xarray dataset to which derived features will be added.
Raises:
ValueError if `datetime` or `lon` are not in `data` coordinates.
"""
for coord in ("datetime", "lon"):
if coord not in data.coords:
raise ValueError(f"'{coord}' must be in `data` coordinates.")
# Compute seconds since epoch.
# Note `data.coords["datetime"].astype("datetime64[s]").astype(np.int64)`
# does not work as xarrays always cast dates into nanoseconds!
seconds_since_epoch = (
data.coords["datetime"].data.astype("datetime64[s]").astype(np.int64)
)
batch_dim = ("batch",) if "batch" in data.dims else ()
# Add year progress features if missing.
if YEAR_PROGRESS not in data.data_vars:
year_progress = get_year_progress(seconds_since_epoch)
data.update(
featurize_progress(
name=YEAR_PROGRESS,
dims=batch_dim + ("time",),
progress=year_progress,
)
)
# Add day progress features if missing.
if DAY_PROGRESS not in data.data_vars:
longitude_coord = data.coords["lon"]
day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data)
data.update(
featurize_progress(
name=DAY_PROGRESS,
dims=batch_dim + ("time",) + longitude_coord.dims,
progress=day_progress,
)
)
def add_tisr_var(data: xarray.Dataset) -> None:
"""Adds TISR feature to `data` in place if missing.
Args:
data: Xarray dataset to which TISR feature will be added.
Raises:
ValueError if `datetime`, 'lat', or `lon` are not in `data` coordinates.
"""
if TISR in data.data_vars:
return
for coord in ("datetime", "lat", "lon"):
if coord not in data.coords:
raise ValueError(f"'{coord}' must be in `data` coordinates.")
# Remove `batch` dimension of size one if present. An error will be raised if
# the `batch` dimension exists and has size greater than one.
data_no_batch = data.squeeze("batch") if "batch" in data.dims else data
tisr = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
data_no_batch, use_jit=True
)
if "batch" in data.dims:
tisr = tisr.expand_dims("batch", axis=0)
data.update({TISR: tisr})
def extract_input_target_times(
dataset: xarray.Dataset,
input_duration: TimedeltaLike,
target_lead_times: TargetLeadTimes,
) -> Tuple[xarray.Dataset, xarray.Dataset]:
"""Extracts inputs and targets for prediction, from a Dataset with a time dim.
The input period is assumed to be contiguous (specified by a duration), but
the targets can be a list of arbitrary lead times.
Examples:
# Use 18 hours of data as inputs, and two specific lead times as targets:
# 3 days and 5 days after the final input.
extract_inputs_targets(
dataset,
input_duration='18h',
target_lead_times=('3d', '5d')
)
# Use 1 day of data as input, and all lead times between 6 hours and
# 24 hours inclusive as targets. Demonstrates a friendlier supported string
# syntax.
extract_inputs_targets(
dataset,
input_duration='1 day',
target_lead_times=slice('6 hours', '24 hours')
)
# Just use a single target lead time of 3 days:
extract_inputs_targets(
dataset,
input_duration='24h',
target_lead_times='3d'
)
Args:
dataset: An xarray.Dataset with a 'time' dimension whose coordinates are
timedeltas. It's assumed that the time coordinates have a fixed offset /
time resolution, and that the input_duration and target_lead_times are
multiples of this.
input_duration: pandas.Timedelta or something convertible to it (e.g. a
shorthand string like '6h' or '5d12h').
target_lead_times: Either a single lead time, a slice with start and stop
(inclusive) lead times, or a sequence of lead times. Lead times should be
Timedeltas (or something convertible to). They are given relative to the
final input timestep, and should be positive.
Returns:
inputs:
targets:
Two datasets with the same shape as the input dataset except that a
selection has been made from the time axis, and the origin of the
time coordinate will be shifted to refer to lead times relative to the
final input timestep. So for inputs the times will end at lead time 0,
for targets the time coordinates will refer to the lead times requested.
"""
(target_lead_times, target_duration
) = _process_target_lead_times_and_get_duration(target_lead_times)
# Shift the coordinates for the time axis so that a timedelta of zero
# corresponds to the forecast reference time. That is, the final timestep
# that's available as input to the forecast, with all following timesteps
# forming the target period which needs to be predicted.
# This means the time coordinates are now forecast lead times.
time = dataset.coords["time"]
dataset = dataset.assign_coords(time=time + target_duration - time[-1])
# Slice out targets:
targets = dataset.sel({"time": target_lead_times})
input_duration = pd.Timedelta(input_duration)
# Both endpoints are inclusive with label-based slicing, so we offset by a
# small epsilon to make one of the endpoints non-inclusive:
zero = pd.Timedelta(0)
epsilon = pd.Timedelta(1, "ns")
inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)})
return inputs, targets
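# Illustrative usage sketch (not part of the original module): applies
# `extract_input_target_times` to a toy dataset whose `time` coordinate is a
# sequence of 6-hourly timedeltas. Variable names and values are arbitrary.
def _example_extract_input_target_times():
  import numpy as np  # Local import so the sketch is self-contained.
  time = np.arange(5) * np.timedelta64(6, "h")  # 0h, 6h, 12h, 18h, 24h.
  dataset = xarray.Dataset(
      data_vars={"var1": (("time",), np.arange(5, dtype=float))},
      coords={"time": time},
  )
  # Two input steps (12 hours of data) ending at lead time 0, plus targets at
  # lead times of 6 and 12 hours after the final input timestep.
  inputs, targets = extract_input_target_times(
      dataset, input_duration="12h", target_lead_times=("6h", "12h"))
  return inputs, targets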
def _process_target_lead_times_and_get_duration(
    target_lead_times: TargetLeadTimes,
) -> Tuple[TargetLeadTimes, pd.Timedelta]:
  """Returns the processed lead times and the duration of the longest one."""
if isinstance(target_lead_times, slice):
# A slice of lead times. xarray already accepts timedelta-like values for
# the begin/end/step of the slice.
if target_lead_times.start is None:
# If the start isn't specified, we assume it starts at the next timestep
# after lead time 0 (lead time 0 is the final input timestep):
target_lead_times = slice(
pd.Timedelta(1, "ns"), target_lead_times.stop, target_lead_times.step
)
target_duration = pd.Timedelta(target_lead_times.stop)
else:
if not isinstance(target_lead_times, (list, tuple, set)):
# A single lead time, which we wrap as a length-1 array to ensure there
# still remains a time dimension (here of length 1) for consistency.
target_lead_times = [target_lead_times]
# A list of multiple (not necessarily contiguous) lead times:
target_lead_times = [pd.Timedelta(x) for x in target_lead_times]
target_lead_times.sort()
target_duration = target_lead_times[-1]
return target_lead_times, target_duration
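# Illustrative sketch (not part of the original module) of what the helper
# above returns for the two kinds of `target_lead_times` it accepts.
def _example_process_target_lead_times():
  # For a slice, the duration comes from the slice's stop value.
  _, duration = _process_target_lead_times_and_get_duration(
      slice("6h", "24h"))
  assert duration == pd.Timedelta("24h")
  # A single lead time is wrapped in a list so that a time dimension of
  # length 1 is preserved downstream.
  lead_times, duration = _process_target_lead_times_and_get_duration("3d")
  assert lead_times == [pd.Timedelta("3d")]
  assert duration == pd.Timedelta("3d")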
def extract_inputs_targets_forcings(
dataset: xarray.Dataset,
*,
input_variables: Tuple[str, ...],
target_variables: Tuple[str, ...],
forcing_variables: Tuple[str, ...],
pressure_levels: Tuple[int, ...],
input_duration: TimedeltaLike,
target_lead_times: TargetLeadTimes,
) -> Tuple[xarray.Dataset, xarray.Dataset, xarray.Dataset]:
"""Extracts inputs, targets and forcings according to requirements."""
dataset = dataset.sel(level=list(pressure_levels))
# "Forcings" include derived variables that do not exist in the original ERA5
# or HRES datasets, as well as other variables (e.g. tisr) that need to be
# computed manually for the target lead times. Compute the requested ones.
if set(forcing_variables) & _DERIVED_VARS:
add_derived_vars(dataset)
if set(forcing_variables) & {TISR}:
add_tisr_var(dataset)
# `datetime` is needed by add_derived_vars but breaks autoregressive rollouts.
dataset = dataset.drop_vars("datetime")
inputs, targets = extract_input_target_times(
dataset,
input_duration=input_duration,
target_lead_times=target_lead_times)
if set(forcing_variables) & set(target_variables):
raise ValueError(
f"Forcing variables {forcing_variables} should not "
f"overlap with target variables {target_variables}."
)
inputs = inputs[list(input_variables)]
# The forcing uses the same time coordinates as the target.
forcings = targets[list(forcing_variables)]
targets = targets[list(target_variables)]
return inputs, targets, forcings
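# Illustrative call sketch (not part of the original module): shows the
# keyword arguments `extract_inputs_targets_forcings` expects. The variable
# names and pressure levels below are hypothetical placeholders; a real call
# would use the names and levels present in the loaded dataset.
def _example_extract_inputs_targets_forcings(dataset: xarray.Dataset):
  return extract_inputs_targets_forcings(
      dataset,
      input_variables=("2m_temperature",),
      target_variables=("2m_temperature",),
      forcing_variables=(TISR, YEAR_PROGRESS, DAY_PROGRESS),
      pressure_levels=(500, 850),
      input_duration="12h",
      target_lead_times="6h",
  )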
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `data_utils.py`."""
import datetime
from absl.testing import absltest
from absl.testing import parameterized
from graphcast import data_utils
import numpy as np
import xarray as xa
class DataUtilsTest(parameterized.TestCase):
def setUp(self):
super().setUp()
# Fix the seed for reproducibility.
np.random.seed(0)
def test_year_progress_is_zero_at_year_start_or_end(self):
year_progress = data_utils.get_year_progress(
np.array([
0,
data_utils.AVG_SEC_PER_YEAR,
data_utils.AVG_SEC_PER_YEAR * 42, # 42 years.
])
)
np.testing.assert_array_equal(year_progress, np.zeros(year_progress.shape))
def test_year_progress_is_almost_one_before_year_ends(self):
year_progress = data_utils.get_year_progress(
np.array([
data_utils.AVG_SEC_PER_YEAR - 1,
(data_utils.AVG_SEC_PER_YEAR - 1) * 42, # ~42 years
])
)
with self.subTest("Year progress values are close to 1"):
self.assertTrue(np.all(year_progress > 0.999))
with self.subTest("Year progress values != 1"):
self.assertTrue(np.all(year_progress < 1.0))
def test_day_progress_computes_for_all_times_and_longitudes(self):
times = np.random.randint(low=0, high=1e10, size=10)
longitudes = np.arange(0, 360.0, 1.0)
day_progress = data_utils.get_day_progress(times, longitudes)
with self.subTest("Day progress is computed for all times and longinutes"):
self.assertSequenceEqual(
day_progress.shape, (len(times), len(longitudes))
)
@parameterized.named_parameters(
dict(
testcase_name="random_date_1",
year=1988,
month=11,
day=7,
hour=2,
minute=45,
second=34,
),
dict(
testcase_name="random_date_2",
year=2022,
month=3,
day=12,
hour=7,
minute=1,
second=0,
),
)
def test_day_progress_is_in_between_zero_and_one(
self, year, month, day, hour, minute, second
):
# Datetime from a timestamp.
dt = datetime.datetime(year, month, day, hour, minute, second)
# Epoch time.
epoch_time = datetime.datetime(1970, 1, 1)
# Seconds since epoch.
seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()])
# Longitudes with 1 degree resolution.
longitudes = np.arange(0, 360.0, 1.0)
day_progress = data_utils.get_day_progress(seconds_since_epoch, longitudes)
with self.subTest("Day progress >= 0"):
self.assertTrue(np.all(day_progress >= 0.0))
with self.subTest("Day progress < 1"):
self.assertTrue(np.all(day_progress < 1.0))
def test_day_progress_is_zero_at_day_start_or_end(self):
day_progress = data_utils.get_day_progress(
seconds_since_epoch=np.array([
0,
data_utils.SEC_PER_DAY,
data_utils.SEC_PER_DAY * 42, # 42 days.
]),
longitude=np.array([0.0]),
)
np.testing.assert_array_equal(day_progress, np.zeros(day_progress.shape))
def test_day_progress_specific_value(self):
day_progress = data_utils.get_day_progress(
seconds_since_epoch=np.array([123]),
longitude=np.array([0.0]),
)
np.testing.assert_array_almost_equal(
day_progress, np.array([[0.00142361]]), decimal=6
)
def test_featurize_progress_valid_values_and_dimensions(self):
day_progress = np.array([0.0, 0.45, 0.213])
feature_dimensions = ("time",)
progress_features = data_utils.featurize_progress(
name="day_progress", dims=feature_dimensions, progress=day_progress
)
for feature in progress_features.values():
with self.subTest(f"Valid dimensions for {feature}"):
self.assertSequenceEqual(feature.dims, feature_dimensions)
with self.subTest("Valid values for day_progress"):
np.testing.assert_array_equal(
day_progress, progress_features["day_progress"].values
)
with self.subTest("Valid values for day_progress_sin"):
np.testing.assert_array_almost_equal(
np.array([0.0, 0.30901699, 0.97309851]),
progress_features["day_progress_sin"].values,
decimal=6,
)
with self.subTest("Valid values for day_progress_cos"):
np.testing.assert_array_almost_equal(
np.array([1.0, -0.95105652, 0.23038943]),
progress_features["day_progress_cos"].values,
decimal=6,
)
def test_featurize_progress_invalid_dimensions(self):
year_progress = np.array([0.0, 0.45, 0.213])
feature_dimensions = ("time", "longitude")
with self.assertRaises(ValueError):
data_utils.featurize_progress(
name="year_progress", dims=feature_dimensions, progress=year_progress
)
def test_add_derived_vars_variables_added(self):
data = xa.Dataset(
data_vars={
"var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3))
},
coords={
"lon": np.array([0.0, 0.5]),
"datetime": np.array([
datetime.datetime(2021, 1, 1),
datetime.datetime(2023, 1, 1),
datetime.datetime(2023, 1, 3),
]),
},
)
data_utils.add_derived_vars(data)
all_variables = set(data.variables)
with self.subTest("Original value was not removed"):
self.assertIn("var1", all_variables)
with self.subTest("Year progress feature was added"):
self.assertIn(data_utils.YEAR_PROGRESS, all_variables)
with self.subTest("Day progress feature was added"):
self.assertIn(data_utils.DAY_PROGRESS, all_variables)
def test_add_derived_vars_existing_vars_not_overridden(self):
dims = ["x", "lon", "datetime"]
data = xa.Dataset(
data_vars={
"var1": (dims, 8 * np.random.randn(2, 2, 3)),
data_utils.YEAR_PROGRESS: (dims, np.full((2, 2, 3), 0.111)),
data_utils.DAY_PROGRESS: (dims, np.full((2, 2, 3), 0.222)),
},
coords={
"lon": np.array([0.0, 0.5]),
"datetime": np.array([
datetime.datetime(2021, 1, 1),
datetime.datetime(2023, 1, 1),
datetime.datetime(2023, 1, 3),
]),
},
)
data_utils.add_derived_vars(data)
with self.subTest("Year progress feature was not overridden"):
np.testing.assert_allclose(data[data_utils.YEAR_PROGRESS], 0.111)
with self.subTest("Day progress feature was not overridden"):
np.testing.assert_allclose(data[data_utils.DAY_PROGRESS], 0.222)
@parameterized.named_parameters(
dict(testcase_name="missing_datetime", coord_name="lon"),
dict(testcase_name="missing_lon", coord_name="datetime"),
)
def test_add_derived_vars_missing_coordinate_raises_value_error(
self, coord_name
):
with self.subTest(f"Missing {coord_name} coordinate"):
data = xa.Dataset(
data_vars={"var1": (["x", coord_name], 8 * np.random.randn(2, 2))},
coords={
coord_name: np.array([0.0, 0.5]),
},
)
with self.assertRaises(ValueError):
data_utils.add_derived_vars(data)
def test_add_tisr_var_variable_added(self):
data = xa.Dataset(
data_vars={
"var1": (["time", "lat", "lon"], np.full((2, 2, 2), 8.0))
},
coords={
"lat": np.array([2.0, 1.0]),
"lon": np.array([0.0, 0.5]),
"time": np.array([100, 200], dtype="timedelta64[s]"),
"datetime": xa.Variable(
"time", np.array([10, 20], dtype="datetime64[D]")
),
},
)
data_utils.add_tisr_var(data)
self.assertIn(data_utils.TISR, set(data.variables))
def test_add_tisr_var_existing_var_not_overridden(self):
dims = ["time", "lat", "lon"]
data = xa.Dataset(
data_vars={
"var1": (dims, np.full((2, 2, 2), 8.0)),
data_utils.TISR: (dims, np.full((2, 2, 2), 1200.0)),
},
coords={
"lat": np.array([2.0, 1.0]),
"lon": np.array([0.0, 0.5]),
"time": np.array([100, 200], dtype="timedelta64[s]"),
"datetime": xa.Variable(
"time", np.array([10, 20], dtype="datetime64[D]")
),
},
)
    data_utils.add_tisr_var(data)
np.testing.assert_allclose(data[data_utils.TISR], 1200.0)
def test_add_tisr_var_works_with_batch_dim_size_one(self):
data = xa.Dataset(
data_vars={
"var1": (
["batch", "time", "lat", "lon"],
np.full((1, 2, 2, 2), 8.0),
)
},
coords={
"lat": np.array([2.0, 1.0]),
"lon": np.array([0.0, 0.5]),
"time": np.array([100, 200], dtype="timedelta64[s]"),
"datetime": xa.Variable(
("batch", "time"), np.array([[10, 20]], dtype="datetime64[D]")
),
},
)
data_utils.add_tisr_var(data)
self.assertIn(data_utils.TISR, set(data.variables))
def test_add_tisr_var_fails_with_batch_dim_size_greater_than_one(self):
data = xa.Dataset(
data_vars={
"var1": (
["batch", "time", "lat", "lon"],
np.full((2, 2, 2, 2), 8.0),
)
},
coords={
"lat": np.array([2.0, 1.0]),
"lon": np.array([0.0, 0.5]),
"time": np.array([100, 200], dtype="timedelta64[s]"),
"datetime": xa.Variable(
("batch", "time"),
np.array([[10, 20], [100, 200]], dtype="datetime64[D]"),
),
},
)
with self.assertRaisesRegex(ValueError, r"cannot select a dimension"):
data_utils.add_tisr_var(data)
if __name__ == "__main__":
absltest.main()
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX implementation of Graph Networks Simulator.
Generalization to TypedGraphs of the deep Graph Neural Network from:
@inproceedings{pfaff2021learning,
title={Learning Mesh-Based Simulation with Graph Networks},
author={Pfaff, Tobias and Fortunato, Meire and Sanchez-Gonzalez, Alvaro and
Battaglia, Peter},
booktitle={International Conference on Learning Representations},
year={2021}
}
@inproceedings{sanchez2020learning,
title={Learning to simulate complex physics with graph networks},
author={Sanchez-Gonzalez, Alvaro and Godwin, Jonathan and Pfaff, Tobias and
Ying, Rex and Leskovec, Jure and Battaglia, Peter},
booktitle={International conference on machine learning},
pages={8459--8468},
year={2020},
organization={PMLR}
}
"""
from typing import Mapping, Optional
from graphcast import typed_graph
from graphcast import typed_graph_net
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
class DeepTypedGraphNet(hk.Module):
"""Deep Graph Neural Network.
It works with TypedGraphs with typed nodes and edges. It runs message
passing on all of the node sets and all of the edge sets in the graph. For
each message passing step a `typed_graph_net.InteractionNetwork` is used to
update the full TypedGraph by using different MLPs for each of the node sets
and each of the edge sets.
  If embed_{nodes,edges} is specified, the node/edge features will be embedded
  into a fixed dimensionality before running the first step of message passing.
  If {node,edge}_output_size is specified, the final node/edge features will be
  embedded into the specified output size.
  This class may be used for shared or unshared message passing:
  * num_message_passing_steps = N, num_processor_repetitions = 1, gives
    N layers of message passing with fully unshared weights:
    [W_1, W_2, ... , W_N] (default)
  * num_message_passing_steps = 1, num_processor_repetitions = M, gives
    M layers of message passing with fully shared weights:
    [W_1] * M
* num_message_passing_steps = N, num_processor_repetitions = M, gives
M*N layers of message passing with both shared and unshared message passing
such that the weights used at each iteration are:
[W_1, W_2, ... , W_N] * M
"""
def __init__(self,
*,
node_latent_size: Mapping[str, int],
edge_latent_size: Mapping[str, int],
mlp_hidden_size: int,
mlp_num_hidden_layers: int,
num_message_passing_steps: int,
num_processor_repetitions: int = 1,
embed_nodes: bool = True,
embed_edges: bool = True,
node_output_size: Optional[Mapping[str, int]] = None,
edge_output_size: Optional[Mapping[str, int]] = None,
include_sent_messages_in_node_update: bool = False,
use_layer_norm: bool = True,
activation: str = "relu",
f32_aggregation: bool = False,
aggregate_edges_for_nodes_fn: str = "segment_sum",
aggregate_normalization: Optional[float] = None,
name: str = "DeepTypedGraphNet"):
"""Inits the model.
Args:
node_latent_size: Size of the node latent representations.
edge_latent_size: Size of the edge latent representations.
mlp_hidden_size: Hidden layer size for all MLPs.
mlp_num_hidden_layers: Number of hidden layers in all MLPs.
num_message_passing_steps: Number of unshared message passing steps
in the processor steps.
num_processor_repetitions: Number of times that the same processor is
        applied sequentially.
embed_nodes: If False, the node embedder will be omitted.
embed_edges: If False, the edge embedder will be omitted.
node_output_size: Size of the output node representations for
each node type. For node types not specified here, the latent node
representation from the output of the processor will be returned.
edge_output_size: Size of the output edge representations for
each edge type. For edge types not specified here, the latent edge
representation from the output of the processor will be returned.
include_sent_messages_in_node_update: Whether to include pooled sent
messages from each node in the node update.
      use_layer_norm: Whether to use layer normalization in the embedder and
        processor MLPs.
      activation: Name of the activation function.
      f32_aggregation: Whether to cast to float32 for the edge aggregation.
      aggregate_edges_for_nodes_fn: Function used to aggregate messages to each
        node.
      aggregate_normalization: An optional constant that normalizes the output
        of aggregate_edges_for_nodes_fn. For context, this can be used to
        reduce the shock the model undergoes when switching resolution, which
        increases the number of edges connected to a node. In particular, this
        is useful when using segment_sum, but should not be combined with
        segment_mean.
name: Name of the model.
"""
super().__init__(name=name)
self._node_latent_size = node_latent_size
self._edge_latent_size = edge_latent_size
self._mlp_hidden_size = mlp_hidden_size
self._mlp_num_hidden_layers = mlp_num_hidden_layers
self._num_message_passing_steps = num_message_passing_steps
self._num_processor_repetitions = num_processor_repetitions
self._embed_nodes = embed_nodes
self._embed_edges = embed_edges
self._node_output_size = node_output_size
self._edge_output_size = edge_output_size
self._include_sent_messages_in_node_update = (
include_sent_messages_in_node_update)
self._use_layer_norm = use_layer_norm
self._activation = _get_activation_fn(activation)
self._initialized = False
self._f32_aggregation = f32_aggregation
self._aggregate_edges_for_nodes_fn = _get_aggregate_edges_for_nodes_fn(
aggregate_edges_for_nodes_fn)
self._aggregate_normalization = aggregate_normalization
if aggregate_normalization:
# using aggregate_normalization only makes sense with segment_sum.
assert aggregate_edges_for_nodes_fn == "segment_sum"
def __call__(self,
input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
"""Forward pass of the learnable dynamics model."""
self._networks_builder(input_graph)
# Embed input features (if applicable).
latent_graph_0 = self._embed(input_graph)
# Do `m` message passing steps in the latent graphs.
latent_graph_m = self._process(latent_graph_0)
# Compute outputs from the last latent graph (if applicable).
return self._output(latent_graph_m)
def _networks_builder(self, graph_template):
if self._initialized:
return
self._initialized = True
def build_mlp(name, output_size):
mlp = hk.nets.MLP(
output_sizes=[self._mlp_hidden_size] * self._mlp_num_hidden_layers + [
output_size], name=name + "_mlp", activation=self._activation)
return jraph.concatenated_args(mlp)
def build_mlp_with_maybe_layer_norm(name, output_size):
network = build_mlp(name, output_size)
if self._use_layer_norm:
layer_norm = hk.LayerNorm(
axis=-1, create_scale=True, create_offset=True,
name=name + "_layer_norm")
network = hk.Sequential([network, layer_norm])
return jraph.concatenated_args(network)
# The embedder graph network independently embeds edge and node features.
if self._embed_edges:
embed_edge_fn = _build_update_fns_for_edge_types(
build_mlp_with_maybe_layer_norm,
graph_template,
"encoder_edges_",
output_sizes=self._edge_latent_size)
else:
embed_edge_fn = None
if self._embed_nodes:
embed_node_fn = _build_update_fns_for_node_types(
build_mlp_with_maybe_layer_norm,
graph_template,
"encoder_nodes_",
output_sizes=self._node_latent_size)
else:
embed_node_fn = None
embedder_kwargs = dict(
embed_edge_fn=embed_edge_fn,
embed_node_fn=embed_node_fn,
)
self._embedder_network = typed_graph_net.GraphMapFeatures(
**embedder_kwargs)
if self._f32_aggregation:
def aggregate_fn(data, *args, **kwargs):
dtype = data.dtype
data = data.astype(jnp.float32)
output = self._aggregate_edges_for_nodes_fn(data, *args, **kwargs)
if self._aggregate_normalization:
output = output / self._aggregate_normalization
output = output.astype(dtype)
return output
else:
def aggregate_fn(data, *args, **kwargs):
output = self._aggregate_edges_for_nodes_fn(data, *args, **kwargs)
if self._aggregate_normalization:
output = output / self._aggregate_normalization
return output
# Create `num_message_passing_steps` graph networks with unshared parameters
# that update the node and edge latent features.
    # Note that we can use `typed_graph_net.InteractionNetwork` because
# it also outputs the messages as updated edge latent features.
self._processor_networks = []
for step_i in range(self._num_message_passing_steps):
self._processor_networks.append(
typed_graph_net.InteractionNetwork(
update_edge_fn=_build_update_fns_for_edge_types(
build_mlp_with_maybe_layer_norm,
graph_template,
f"processor_edges_{step_i}_",
output_sizes=self._edge_latent_size),
update_node_fn=_build_update_fns_for_node_types(
build_mlp_with_maybe_layer_norm,
graph_template,
f"processor_nodes_{step_i}_",
output_sizes=self._node_latent_size),
aggregate_edges_for_nodes_fn=aggregate_fn,
include_sent_messages_in_node_update=(
self._include_sent_messages_in_node_update),
))
    # The output MLPs convert edge/node latent features into the output sizes.
output_kwargs = dict(
embed_edge_fn=_build_update_fns_for_edge_types(
build_mlp, graph_template, "decoder_edges_", self._edge_output_size)
if self._edge_output_size else None,
embed_node_fn=_build_update_fns_for_node_types(
build_mlp, graph_template, "decoder_nodes_", self._node_output_size)
if self._node_output_size else None,)
self._output_network = typed_graph_net.GraphMapFeatures(
**output_kwargs)
def _embed(
self, input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
"""Embeds the input graph features into a latent graph."""
# Copy the context to all of the node types, if applicable.
context_features = input_graph.context.features
if jax.tree_util.tree_leaves(context_features):
# This code assumes a single input feature array for the context and for
# each node type.
assert len(jax.tree_util.tree_leaves(context_features)) == 1
new_nodes = {}
for node_set_name, node_set in input_graph.nodes.items():
node_features = node_set.features
broadcasted_context = jnp.repeat(
context_features, node_set.n_node, axis=0,
total_repeat_length=node_features.shape[0])
new_nodes[node_set_name] = node_set._replace(
features=jnp.concatenate(
[node_features, broadcasted_context], axis=-1))
input_graph = input_graph._replace(
nodes=new_nodes,
context=input_graph.context._replace(features=()))
# Embeds the node and edge features.
latent_graph_0 = self._embedder_network(input_graph)
return latent_graph_0
def _process(
self, latent_graph_0: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
"""Processes the latent graph with several steps of message passing."""
# Do `num_message_passing_steps` with each of the `self._processor_networks`
# with unshared weights, and repeat that `self._num_processor_repetitions`
# times.
latent_graph = latent_graph_0
for unused_repetition_i in range(self._num_processor_repetitions):
for processor_network in self._processor_networks:
latent_graph = self._process_step(processor_network, latent_graph)
return latent_graph
def _process_step(
self, processor_network_k,
latent_graph_prev_k: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
"""Single step of message passing with node/edge residual connections."""
# One step of message passing.
latent_graph_k = processor_network_k(latent_graph_prev_k)
# Add residuals.
nodes_with_residuals = {}
for k, prev_set in latent_graph_prev_k.nodes.items():
nodes_with_residuals[k] = prev_set._replace(
features=prev_set.features + latent_graph_k.nodes[k].features)
edges_with_residuals = {}
for k, prev_set in latent_graph_prev_k.edges.items():
edges_with_residuals[k] = prev_set._replace(
features=prev_set.features + latent_graph_k.edges[k].features)
latent_graph_k = latent_graph_k._replace(
nodes=nodes_with_residuals, edges=edges_with_residuals)
return latent_graph_k
def _output(self,
latent_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
"""Produces the output from the latent graph."""
return self._output_network(latent_graph)
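# Illustrative construction sketch (not part of the original module): shows how
# DeepTypedGraphNet might be instantiated inside a Haiku transform. The node
# and edge set names ("nodes", "edges") and all sizes here are hypothetical
# placeholders; a real caller would use the set names of its own TypedGraph.
def _example_build_network() -> hk.Transformed:
  def forward(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
    net = DeepTypedGraphNet(
        node_latent_size={"nodes": 128},
        edge_latent_size={"edges": 128},
        mlp_hidden_size=128,
        mlp_num_hidden_layers=1,
        num_message_passing_steps=4,
        node_output_size={"nodes": 8},
        use_layer_norm=True,
        activation="swish",
        aggregate_edges_for_nodes_fn="segment_sum")
    return net(graph)
  # hk.transform turns the closure into a pair of pure init/apply functions.
  return hk.transform(forward)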
def _build_update_fns_for_node_types(
builder_fn, graph_template, prefix, output_sizes=None):
"""Builds an update function for all node types or a subset of them."""
output_fns = {}
for node_set_name in graph_template.nodes.keys():
if output_sizes is None:
# Use the default output size for all types.
output_size = None
else:
# Otherwise, ignore any type that does not have an explicit output size.
if node_set_name in output_sizes:
output_size = output_sizes[node_set_name]
else:
continue
output_fns[node_set_name] = builder_fn(
f"{prefix}{node_set_name}", output_size)
return output_fns
def _build_update_fns_for_edge_types(
builder_fn, graph_template, prefix, output_sizes=None):
"""Builds an edge function for all node types or a subset of them."""
output_fns = {}
for edge_set_key in graph_template.edges.keys():
edge_set_name = edge_set_key.name
if output_sizes is None:
# Use the default output size for all types.
output_size = None
else:
# Otherwise, ignore any type that does not have an explicit output size.
if edge_set_name in output_sizes:
output_size = output_sizes[edge_set_name]
else:
continue
output_fns[edge_set_name] = builder_fn(
f"{prefix}{edge_set_name}", output_size)
return output_fns
def _get_activation_fn(name):
"""Return activation function corresponding to function_name."""
if name == "identity":
return lambda x: x
if hasattr(jax.nn, name):
return getattr(jax.nn, name)
if hasattr(jnp, name):
return getattr(jnp, name)
raise ValueError(f"Unknown activation function {name} specified.")
def _get_aggregate_edges_for_nodes_fn(name):
"""Return aggregate_edges_for_nodes_fn corresponding to function_name."""
if hasattr(jraph, name):
return getattr(jraph, name)
raise ValueError(
f"Unknown aggregate_edges_for_nodes_fn function {name} specified.")
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for converting from regular grids on a sphere, to triangular meshes."""
from graphcast import icosahedral_mesh
import numpy as np
import scipy
import trimesh
def _grid_lat_lon_to_coordinates(
grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray:
"""Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3]."""
# Convert to spherical coordinates phi and theta defined in the grid.
# Each [num_latitude_points, num_longitude_points]
phi_grid, theta_grid = np.meshgrid(
np.deg2rad(grid_longitude),
np.deg2rad(90 - grid_latitude))
# [num_latitude_points, num_longitude_points, 3]
# Note this assumes unit radius, since for now we model the earth as a
# sphere of unit radius, and keep any vertical dimension as a regular grid.
return np.stack(
[np.cos(phi_grid)*np.sin(theta_grid),
np.sin(phi_grid)*np.sin(theta_grid),
np.cos(theta_grid)], axis=-1)
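# Illustrative check (not part of the original module): with the convention
# above (phi = longitude, theta = 90 - latitude, unit radius), latitude 0 /
# longitude 0 maps to (1, 0, 0) and the north pole maps to (0, 0, 1).
def _example_grid_lat_lon_to_coordinates():
  coords = _grid_lat_lon_to_coordinates(
      grid_latitude=np.array([0.0, 90.0]), grid_longitude=np.array([0.0]))
  np.testing.assert_allclose(coords[0, 0], [1.0, 0.0, 0.0], atol=1e-12)
  np.testing.assert_allclose(coords[1, 0], [0.0, 0.0, 1.0], atol=1e-12)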
def radius_query_indices(
*,
grid_latitude: np.ndarray,
grid_longitude: np.ndarray,
mesh: icosahedral_mesh.TriangularMesh,
radius: float) -> tuple[np.ndarray, np.ndarray]:
"""Returns mesh-grid edge indices for radius query.
Args:
grid_latitude: Latitude values for the grid [num_lat_points]
grid_longitude: Longitude values for the grid [num_lon_points]
mesh: Mesh object.
    radius: Radius of connectivity in R3, for a sphere of unit radius.
Returns:
tuple with `grid_indices` and `mesh_indices` indicating edges between the
grid and the mesh such that the distances in a straight line (not geodesic)
are smaller than or equal to `radius`.
* grid_indices: Indices of shape [num_edges], that index into a
[num_lat_points, num_lon_points] grid, after flattening the leading axes.
* mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
"""
# [num_grid_points=num_lat_points * num_lon_points, 3]
grid_positions = _grid_lat_lon_to_coordinates(
grid_latitude, grid_longitude).reshape([-1, 3])
# [num_mesh_points, 3]
mesh_positions = mesh.vertices
kd_tree = scipy.spatial.cKDTree(mesh_positions)
# [num_grid_points, num_mesh_points_per_grid_point]
# Note `num_mesh_points_per_grid_point` is not constant, so this is a list
# of arrays, rather than a 2d array.
query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius)
grid_edge_indices = []
mesh_edge_indices = []
for grid_index, mesh_neighbors in enumerate(query_indices):
grid_edge_indices.append(np.repeat(grid_index, len(mesh_neighbors)))
mesh_edge_indices.append(mesh_neighbors)
# [num_edges]
grid_edge_indices = np.concatenate(grid_edge_indices, axis=0).astype(int)
mesh_edge_indices = np.concatenate(mesh_edge_indices, axis=0).astype(int)
return grid_edge_indices, mesh_edge_indices
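# Illustrative usage sketch (not part of the original module), mirroring the
# smoke test in grid_mesh_connectivity_test.py: connects a coarse lat/lon grid
# to the finest mesh in a small icosahedral hierarchy. The grid resolution,
# number of splits and radius are arbitrary choices.
def _example_radius_query_indices():
  grid_latitude = np.linspace(-75.0, 75.0, 6)
  grid_longitude = np.arange(12) * 30.0
  mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
      splits=3)[-1]
  grid_indices, mesh_indices = radius_query_indices(
      grid_latitude=grid_latitude,
      grid_longitude=grid_longitude,
      mesh=mesh,
      radius=0.2)
  # Each pair (grid_indices[i], mesh_indices[i]) is one grid-to-mesh edge.
  return grid_indices, mesh_indices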
def in_mesh_triangle_indices(
*,
grid_latitude: np.ndarray,
grid_longitude: np.ndarray,
mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]:
"""Returns mesh-grid edge indices for grid points contained in mesh triangles.
Args:
grid_latitude: Latitude values for the grid [num_lat_points]
grid_longitude: Longitude values for the grid [num_lon_points]
mesh: Mesh object.
Returns:
tuple with `grid_indices` and `mesh_indices` indicating edges between the
    grid and the mesh vertices of the triangle that contains each grid point.
    The number of edges is always num_lat_points * num_lon_points * 3.
* grid_indices: Indices of shape [num_edges], that index into a
[num_lat_points, num_lon_points] grid, after flattening the leading axes.
* mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
"""
# [num_grid_points=num_lat_points * num_lon_points, 3]
grid_positions = _grid_lat_lon_to_coordinates(
grid_latitude, grid_longitude).reshape([-1, 3])
mesh_trimesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
# [num_grid_points] with mesh face indices for each grid point.
_, _, query_face_indices = trimesh.proximity.closest_point(
mesh_trimesh, grid_positions)
# [num_grid_points, 3] with mesh node indices for each grid point.
mesh_edge_indices = mesh.faces[query_face_indices]
# [num_grid_points, 3] with grid node indices, where every row simply contains
# the row (grid_point) index.
grid_indices = np.arange(grid_positions.shape[0])
grid_edge_indices = np.tile(grid_indices.reshape([-1, 1]), [1, 3])
# Flatten to get a regular list.
# [num_edges=num_grid_points*3]
mesh_edge_indices = mesh_edge_indices.reshape([-1])
grid_edge_indices = grid_edge_indices.reshape([-1])
return grid_edge_indices, mesh_edge_indices
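# Illustrative usage sketch (not part of the original module), mirroring the
# smoke test in grid_mesh_connectivity_test.py: for every grid point, returns
# three edges to the vertices of the mesh triangle containing it.
def _example_in_mesh_triangle_indices():
  grid_latitude = np.linspace(-75.0, 75.0, 6)
  grid_longitude = np.arange(12) * 30.0
  mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
      splits=3)[-1]
  grid_indices, mesh_indices = in_mesh_triangle_indices(
      grid_latitude=grid_latitude,
      grid_longitude=grid_longitude,
      mesh=mesh)
  assert grid_indices.shape[0] == 6 * 12 * 3  # num_grid_points * 3 edges.
  return grid_indices, mesh_indices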
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for graphcast.grid_mesh_connectivity."""
from absl.testing import absltest
from graphcast import grid_mesh_connectivity
from graphcast import icosahedral_mesh
import numpy as np
class GridMeshConnectivityTest(absltest.TestCase):
def test_grid_lat_lon_to_coordinates(self):
# Intervals of 30 degrees.
grid_latitude = np.array([-45., 0., 45])
grid_longitude = np.array([0., 90., 180., 270.])
inv_sqrt2 = 1 / np.sqrt(2)
expected_coordinates = np.array([
[[inv_sqrt2, 0., -inv_sqrt2],
[0., inv_sqrt2, -inv_sqrt2],
[-inv_sqrt2, 0., -inv_sqrt2],
[0., -inv_sqrt2, -inv_sqrt2]],
[[1., 0., 0.],
[0., 1., 0.],
[-1., 0., 0.],
[0., -1., 0.]],
[[inv_sqrt2, 0., inv_sqrt2],
[0., inv_sqrt2, inv_sqrt2],
[-inv_sqrt2, 0., inv_sqrt2],
[0., -inv_sqrt2, inv_sqrt2]],
])
coordinates = grid_mesh_connectivity._grid_lat_lon_to_coordinates(
grid_latitude, grid_longitude)
np.testing.assert_allclose(expected_coordinates, coordinates, atol=1e-15)
def test_radius_query_indices_smoke(self):
# TODO(alvarosg): Add non-smoke test?
grid_latitude = np.linspace(-75, 75, 6)
grid_longitude = np.arange(12) * 30.
mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
splits=3)[-1]
grid_mesh_connectivity.radius_query_indices(
grid_latitude=grid_latitude,
grid_longitude=grid_longitude,
mesh=mesh, radius=0.2)
def test_in_mesh_triangle_indices_smoke(self):
# TODO(alvarosg): Add non-smoke test?
grid_latitude = np.linspace(-75, 75, 6)
grid_longitude = np.arange(12) * 30.
mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
splits=3)[-1]
grid_mesh_connectivity.in_mesh_triangle_indices(
grid_latitude=grid_latitude,
grid_longitude=grid_longitude,
mesh=mesh)
if __name__ == "__main__":
absltest.main()