{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "SCWieCN1c_oP" }, "source": [ "# ML accelerated CFD data analysis\n", "\n", "This notebook reproduces key figures in our [PNAS paper](https://www.pnas.org/content/118/21/e2101784118) based on saved datasets. The data is stored in netCDF files in Google Cloud Storage, and the analysis uses xarray and JAX-CFD." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 12165, "status": "ok", "timestamp": 1635551082356, "user": { "displayName": "Stephan Hoyer", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gh3-wMvU44jUaVFR9jlCY_2pss4FrdtAZbLsUaV=s64", "userId": "01386112912994523038" }, "user_tz": 420 }, "id": "dq5Hou4QzH0Q", "outputId": "74a35f28-bc46-4678-bef4-a696f79bb482" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/modelzoo/jax/jax-cfd\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n", " self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n" ] } ], "source": [ "%cd /home/modelzoo/jax/jax-cfd/" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://mirrors.ustc.edu.cn/pypi/web/simple\n", "Requirement already satisfied: xarray in /usr/local/lib/python3.10/site-packages (2024.5.0)\n", "^C\n" ] } ], "source": [ "! 
pip install -U xarray jax-cfd[data]==0.1.0" ] }, { "cell_type": "markdown", "metadata": { "id": "X852u7zJiRXq" }, "source": [ "## Figure 1\n", "\n", "Replication of the Figure 1 from the PNAS paper, except with a [bootstrap](https://en.wikipedia.org/wiki/Bootstrapping_(statistics)) based estimation of uncertainty, given the sample size of 16 trajectories." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "cellView": "form", "id": "9NMXlZX0iX3o" }, "outputs": [], "source": [ "# @title Utility functions\n", "import xarray\n", "import seaborn\n", "import numpy as np\n", "import pandas as pd\n", "import jax_cfd.data.xarray_utils as xru\n", "from jax_cfd.data import evaluation\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "def correlation(x, y):\n", " state_dims = ['x', 'y']\n", " p = xru.normalize(x, state_dims) * xru.normalize(y, state_dims)\n", " return p.sum(state_dims)\n", "\n", "def calculate_time_until(vorticity_corr):\n", " threshold = 0.95\n", " return (vorticity_corr.mean('sample') >= threshold).idxmin('time').rename('time_until')\n", "\n", "def calculate_time_until_bootstrap(vorticity_corr, bootstrap_samples=10000):\n", "  \"\"\"Bootstrap distribution of `calculate_time_until` over the 'sample' dimension.\n", "\n", "  Args:\n", "    vorticity_corr: xarray object with 'sample' and 'time' dimensions.\n", "    bootstrap_samples: number of resamples drawn with replacement.\n", "\n", "  Returns:\n", "    Result of `calculate_time_until` on the resampled data, with an extra\n", "    'boot' dimension of length `bootstrap_samples`.\n", "  \"\"\"\n", "  num_samples = vorticity_corr.sizes['sample']\n", "  rs = np.random.RandomState(0)  # fixed seed keeps the bootstrap reproducible\n", "  # Size the draw from the inputs rather than hard-coding 16 and 10000, so the\n", "  # `bootstrap_samples` argument is honoured and other sample counts work.\n", "  indices = rs.choice(\n", "      num_samples, size=(bootstrap_samples, num_samples), replace=True)\n", "  boot_vorticity_corr = vorticity_corr.isel(\n", "      sample=(('boot', 'sample2'), indices)).rename({'sample2': 'sample'})\n", "  return calculate_time_until(boot_vorticity_corr)\n", "\n", "def calculate_upscaling(time_until):\n", " slope = ((np.log(16) - np.log(8))\n", " / (time_until.sel(model='baseline_1024')\n", " - time_until.sel(model='baseline_512')))\n", " x = time_until.sel(model='learned_interp_64')\n", " x0 = time_until.sel(model='baseline_512')\n", " intercept = np.log(8)\n", " factor = np.exp(slope * (x - x0) + intercept)\n", " return factor\n", "\n", "def calculate_speedup(time_until):\n", " runtime_baseline_8x = 44.053293\n", " runtime_baseline_16x = 412.725656\n", " runtime_learned = 1.155115\n", " slope = 
((np.log(runtime_baseline_16x) - np.log(runtime_baseline_8x))\n", " / (time_until.sel(model='baseline_1024')\n", " - time_until.sel(model='baseline_512')))\n", " x = time_until.sel(model='learned_interp_64')\n", " x0 = time_until.sel(model='baseline_512')\n", " intercept = np.log(runtime_baseline_8x)\n", " speedups = np.exp(slope * (x - x0) + intercept) / runtime_learned\n", " return speedups" ] }, { "cell_type": "markdown", "metadata": { "id": "nfxj9dxTKAUF" }, "source": [ "### Load data" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 55729, "status": "ok", "timestamp": 1635551143667, "user": { "displayName": "Stephan Hoyer", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gh3-wMvU44jUaVFR9jlCY_2pss4FrdtAZbLsUaV=s64", "userId": "01386112912994523038" }, "user_tz": 420 }, "id": "UNivrWIsh3tL", "outputId": "8b5cb028-6413-444a-fdec-ad54b6e3b209" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/bin/bash: gsutil: command not found\n", "CPU times: user 4.12 ms, sys: 16 ms, total: 20.1 ms\n", "Wall time: 300 ms\n" ] } ], "source": [ "%time ! 
mkdir -p ./content && gsutil -m cp -r gs://gresearch/jax-cfd/public_eval_datasets/kolmogorov_re_1000_fig1 ./content" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 10, "status": "ok", "timestamp": 1635551394768, "user": { "displayName": "Stephan Hoyer", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gh3-wMvU44jUaVFR9jlCY_2pss4FrdtAZbLsUaV=s64", "userId": "01386112912994523038" }, "user_tz": 420 }, "id": "y2hNMbdniSdV", "outputId": "5ddd98a5-5f5b-4a69-917f-efd58ee26deb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/modelzoo/jax/jax-cfd\n", "baseline_1024x1024.nc baseline_32x32.nc learned_32x32.nc\n", "baseline_128x128.nc baseline_512x512.nc learned_64x64.nc\n", "baseline_2048x2048.nc baseline_64x64.nc tpu-speed-measurements.csv\n", "baseline_256x256.nc learned_128x128.nc\n" ] } ], "source": [ "! pwd\n", "! ls ./content/kolmogorov_re_1000_fig1" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "JZOgcLcniX3p" }, "outputs": [], "source": [ "baseline_filenames = {\n", " f'baseline_{r}': f'baseline_{r}x{r}.nc'\n", " for r in [64, 128, 256, 512, 1024, 2048]\n", "}\n", "learned_filenames = {\n", " f'learned_interp_{r}': f'learned_{r}x{r}.nc'\n", " for r in [32, 64, 128]\n", "}\n", "\n", "models = {}\n", "for k, v in baseline_filenames.items():\n", " models[k] = xarray.open_dataset(f'./content/kolmogorov_re_1000_fig1/{v}', chunks={'time': '100MB'})\n", "for k, v in learned_filenames.items():\n", " ds = xarray.open_dataset(f'./content/kolmogorov_re_1000_fig1/{v}', chunks={'time': '100MB'})\n", " models[k] = ds.reindex_like(models['baseline_64'], method='nearest')\n", "\n", "combined_fig1 = xarray.concat(list(models.values()), dim='model')\n", "combined_fig1.coords['model'] = list(models.keys())\n", "combined_fig1['vorticity'] = xru.vorticity_2d(combined_fig1)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { 
"base_uri": "https://localhost:8080/", "height": 327 }, "executionInfo": { "elapsed": 234, "status": "ok", "timestamp": 1635551395424, "user": { "displayName": "Stephan Hoyer", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gh3-wMvU44jUaVFR9jlCY_2pss4FrdtAZbLsUaV=s64", "userId": "01386112912994523038" }, "user_tz": 420 }, "id": "g-qiSd9vx2qv", "outputId": "4303b785-ceb2-4ab7-d9ce-c29fd5e17e4d" }, "outputs": [ { "data": { "text/html": [ "
<xarray.Dataset> Size: 6GB\n",
"Dimensions: (model: 9, sample: 16, time: 2441, x: 32, y: 32)\n",
"Coordinates:\n",
" * time (time) float64 20kB 0.0 0.01402 0.02805 ... 34.19 34.21 34.22\n",
" * x (x) float64 256B 0.09817 0.2945 0.4909 ... 5.792 5.989 6.185\n",
" * y (y) float64 256B 0.09817 0.2945 0.4909 ... 5.792 5.989 6.185\n",
" * sample (sample) int32 64B 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15\n",
" * model (model) <U18 648B 'baseline_64' ... 'learned_interp_128'\n",
"Data variables:\n",
" u (model, sample, time, x, y) float32 1GB dask.array<chunksize=(1, 16, 1525, 32, 32), meta=np.ndarray>\n",
" v (model, sample, time, x, y) float32 1GB dask.array<chunksize=(1, 16, 1525, 32, 32), meta=np.ndarray>\n",
" vorticity (model, sample, time, x, y) float64 3GB dask.array<chunksize=(1, 16, 1525, 32, 32), meta=np.ndarray>\n",
"Attributes: (12/25)\n",
" domain_size: 6.283185307179586\n",
" domain_size_multiple: 1\n",
" full_config_str: import google3.research.simulation.whirl.m...\n",
" init_cfl_safety_factor: 0.5\n",
" init_peak_wavenumber: 4.0\n",
" maximum_velocity: 7.0\n",
" ... ...\n",
" time_subsample_factor: 1\n",
" tracing_max_duration_in_msec: 100.0\n",
" warmup_grid_size: 2048\n",
" warmup_time: 40.0\n",
" xm_experiment_id: 18497215\n",
" xm_work_unit_id: 6| \n", " | model | \n", "resolution | \n", "msec_per_sim_step | \n", "model_name | \n", "msec_per_dt | \n", "
|---|---|---|---|---|---|
| 0 | \n", "DS | \n", "512 | \n", "0.183315 | \n", "baseline_64 | \n", "0.366630 | \n", "
| 1 | \n", "DS | \n", "1024 | \n", "0.862311 | \n", "baseline_128 | \n", "3.449244 | \n", "
| 2 | \n", "DS | \n", "2048 | \n", "3.484289 | \n", "baseline_256 | \n", "27.874312 | \n", "
| 3 | \n", "DS | \n", "4096 | \n", "19.306356 | \n", "baseline_512 | \n", "308.901689 | \n", "
| 4 | \n", "DS | \n", "8192 | \n", "90.438509 | \n", "baseline_1024 | \n", "2894.032298 | \n", "
| 5 | \n", "LI | \n", "256 | \n", "1.111596 | \n", "learned_interp_32 | \n", "1.111596 | \n", "
| 6 | \n", "LI | \n", "512 | \n", "4.049835 | \n", "learned_interp_64 | \n", "8.099669 | \n", "
| 7 | \n", "LI | \n", "1024 | \n", "16.043913 | \n", "learned_interp_128 | \n", "64.175654 | \n", "
<xarray.Dataset> Size: 6GB\n",
"Dimensions: (time: 3477, x: 32, y: 32, sample: 16, model: 7)\n",
"Coordinates:\n",
" * time (time) float64 28kB 0.0 0.07012 0.1402 ... 243.6 243.7 243.8\n",
" * x (x) float64 256B 0.1473 0.3436 0.54 0.7363 ... 5.841 6.038 6.234\n",
" * y (y) float64 256B 0.1473 0.3436 0.54 0.7363 ... 5.841 6.038 6.234\n",
" * sample (sample) int32 64B 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15\n",
" * model (model) <U17 476B 'baseline_64' ... 'learned_interp_64'\n",
"Data variables:\n",
" u (model, sample, time, x, y) float32 2GB dask.array<chunksize=(1, 16, 381, 32, 32), meta=np.ndarray>\n",
" v (model, sample, time, x, y) float32 2GB dask.array<chunksize=(1, 16, 381, 32, 32), meta=np.ndarray>\n",
" vorticity (model, sample, time, x, y) float64 3GB dask.array<chunksize=(1, 16, 381, 32, 32), meta=np.ndarray>\n",
"Attributes: (12/17)\n",
" domain_size: [0. 6.28318531]\n",
" domain_size_multiple: 1\n",
" full_config_str: \\n# Macros:\\n# ===========================...\n",
" init_cfl_safety_factor: 0.5\n",
" init_peak_wavenumber: 4.0\n",
" maximum_velocity: 7.0\n",
" ... ...\n",
" simulation_time: 240.0\n",
" stable_time_step: 0.007012483601762931\n",
" time_subsample_factor: 1\n",
" tracing_max_duration_in_msec: 100.0\n",
" warmup_grid_size: 2048\n",
" warmup_time: 40.0<xarray.Dataset> Size: 17MB\n",
"Dimensions: (model: 7, time: 3477, k: 30)\n",
"Coordinates:\n",
" * time (time) float64 28kB 0.0 0.07012 ... 243.8\n",
" * k (k) float64 240B 1.0 2.0 3.0 ... 29.0 30.0\n",
" * model (model) <U17 476B 'baseline_64' ... 'learne...\n",
"Data variables: (12/31)\n",
" u_mean (model, time) float32 97kB dask.array<chunksize=(1, 381), meta=np.ndarray>\n",
" v_mean (model, time) float32 97kB dask.array<chunksize=(1, 381), meta=np.ndarray>\n",
" kinetic_energy_mean (model, time) float32 97kB dask.array<chunksize=(1, 381), meta=np.ndarray>\n",
" speed_mean (model, time) float32 97kB dask.array<chunksize=(1, 381), meta=np.ndarray>\n",
" energy_spectrum_mean (model, k, time) float64 6MB dask.array<chunksize=(1, 30, 381), meta=np.ndarray>\n",
" enstrophy_mean (model, time) float64 195kB dask.array<chunksize=(1, 381), meta=np.ndarray>\n",
" ... ...\n",
" energy_spectrum_within_q=1.0 (model, k, time) bool 730kB dask.array<chunksize=(1, 30, 381), meta=np.ndarray>\n",
" enstrophy_within_q=1.0 (model, time) float64 195kB dask.array<chunksize=(1, 381), meta=np.ndarray>\n",
" vorticity_within_q=1.0 (model, time) float64 195kB dask.array<chunksize=(1, 381), meta=np.ndarray>\n",
" u_correlation (model, time) float32 97kB dask.array<chunksize=(1, 381), meta=np.ndarray>\n",
" v_correlation (model, time) float32 97kB dask.array<chunksize=(1, 381), meta=np.ndarray>\n",
" vorticity_correlation (model, time) float64 195kB dask.array<chunksize=(1, 381), meta=np.ndarray>