"docs/reference/py.rst" did not exist on "63c5582ab0c8a575641bb2e04591336bf5301eb5"
visualization.py 4.88 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Visualization utilities."""

from typing import Any, BinaryIO, Callable, Optional, List, Tuple, Union
from jax_cfd.base import grids
import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import PIL.Image as Image
import seaborn as sns


# Signature of a normalization factory: it is called with the data of a single
# frame and that frame's index in the series, and returns the
# `mpl.colors.Normalize` used to map the frame's values onto a colormap.
NormFn = Callable[[grids.Array, int], mpl.colors.Normalize]


def quantile_normalize_fn(
    image_data: grids.Array,
    image_num: int,
    quantile: float = 0.999
) -> mpl.colors.Normalize:
  """Builds a symmetric `mpl.colors.Normalize` from a quantile of the data.

  The range is `[-q, q]`, where `q` is the requested quantile of the absolute
  values of `image_data`.

  Args:
    image_data: data for which the `Normalize` object is produced.
    image_num: index of the frame in the series. Accepted only so that the
      function matches the `NormFn` signature; it is ignored.
    quantile: quantile of `abs(image_data)` that sets the edge of the range.

  Returns:
    `mpl.colors.Normalize` whose `vmin` and `vmax` are minus and plus the
    computed quantile.
  """
  del image_num  # Part of the `NormFn` signature, but unused here.
  bound = np.quantile(abs(image_data), quantile)
  return mpl.colors.Normalize(vmin=-bound, vmax=bound)


def resize_image(
    image: Image.Image,
    longest_side: int,
    resample: int = Image.Resampling.NEAREST,
) -> Image.Image:
  """Resizes an image so that its longest side becomes `longest_side` pixels.

  Both sides are scaled by the same factor, so the aspect ratio is preserved up
  to rounding to whole pixels.

  Args:
    image: image to resize.
    longest_side: target length, in pixels, of the longer side of the result.
      Must be positive.
    resample: resampling filter passed through to `Image.resize`.

  Returns:
    The resized image. Every side is at least one pixel long.

  Raises:
    ValueError: if `longest_side` is not positive.
  """
  if longest_side < 1:
    raise ValueError(f'longest_side must be positive, got {longest_side}.')
  resize_factor = longest_side / max(image.size)
  # For very elongated inputs the short side can round down to 0 pixels, which
  # `Image.resize` rejects; keep every side at least one pixel long.
  new_size = tuple(max(1, round(s * resize_factor)) for s in image.size)
  return image.resize(new_size, resample)


def trajectory_to_images(
    trajectory: grids.Array,
    compute_norm_fn: NormFn = quantile_normalize_fn,
    cmap: mpl.colors.ListedColormap = sns.cm.icefire,  # pytype: disable=module-attr
    longest_side: Optional[int] = None,
) -> List[Image.Image]:
  """Renders every frame of a scalar trajectory as a color-mapped image.

  Args:
    trajectory: scalar data that is iterated over its leading (time) axis; each
      item is rendered as one image.
    compute_norm_fn: called with a frame's data and its index to produce the
      normalization applied before the colormap lookup.
    cmap: colormap used to turn normalized values into colors.
    longest_side: if not `None`, every image is resized with `resize_image` so
      that its longest side has this many pixels.

  Returns:
    List with one image per frame, in iteration order.
  """
  frames = []
  for index, frame_data in enumerate(trajectory):
    color_mapper = cm.ScalarMappable(
        norm=compute_norm_fn(frame_data, index), cmap=cmap)
    rgba_bytes = color_mapper.to_rgba(frame_data, bytes=True)
    frame = Image.fromarray(rgba_bytes)
    frames.append(
        frame if longest_side is None else resize_image(frame, longest_side))
  return frames


# TODO(dkochkov) consider generalizing this to a general facet.
def horizontal_facet(
    separate_images: List[List[Image.Image]],
    relative_separation_width: float,
    separation_rgb: Tuple[int, int, int] = (255, 255, 255)
) -> List[Image.Image]:
  """Stitches separate images into a single one with a separation strip.

  For every time step, the frames found at that position in each series are
  pasted side by side, left to right, with a strip of `separation_rgb` color
  between adjacent frames.

  Args:
    separate_images: lists of images each representing time series. All images
      must have the same size. The series are combined with `zip`, so the
      output has as many frames as the shortest series.
    relative_separation_width: width of the separation defined as a fraction of
      the width of a separate image.
    separation_rgb: rgb color code of the separation strip to add between
      adjacent images.

  Returns:
    list of merged images that contain images passed as `separate_images` with
    a separating strip.
  """
  images = []
  for frames in zip(*separate_images):
    images_to_combine = len(frames)
    # The canvas geometry is derived from the first frame of the time step.
    frame_width, image_height = frames[0].size
    separation_width = round(frame_width * relative_separation_width)
    image_width = (frame_width * images_to_combine +
                   separation_width * (images_to_combine - 1))

    sep_im = Image.new('RGB', (separation_width, image_height), separation_rgb)
    full_im = Image.new('RGB', (image_width, image_height))

    width_offset = 0
    height_offset = 0
    for frame in frames:
      full_im.paste(frame, (width_offset, height_offset))
      width_offset += frame.width
      # Once the canvas is filled no strip is added; with equally sized frames
      # this means there is no trailing strip after the last frame.
      if width_offset < full_im.width:
        full_im.paste(sep_im, (width_offset, height_offset))
        width_offset += sep_im.width
    images.append(full_im)
  return images


def save_movie(
    images: List[Image.Image],
    output_path: Union[str, BinaryIO],
    duration: float = 150.,
    loop: int = 0,
    **kwargs: Any
):
  """Saves `images` as a multi-frame movie to `output_path`.

  The first image is saved with `save_all=True` and the remaining images are
  passed as `append_images`.

  Args:
    images: list of images representing time series that will be saved as
      movie. Must contain at least one image.
    output_path: file handle or path where the movie is saved.
    duration: forwarded to `Image.save` as `duration`. NOTE: Pillow's animated
      writers (e.g. GIF) document this as the display time of each frame in
      milliseconds, not the length of the whole movie.
    loop: number of times to loop the movie. 0 interpreted as indefinite.
    **kwargs: additional keyword arguments to be passed to `Image.save`.
  """
  first_frame, remaining_frames = images[0], images[1:]
  first_frame.save(
      output_path,
      save_all=True,
      append_images=remaining_frames,
      duration=duration,
      loop=loop,
      **kwargs)