towers_test.py 2.67 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Tests for google3.research.simulation.whirl.models.towers."""

import itertools
from absl.testing import absltest
from absl.testing import parameterized

import gin
import haiku as hk
import jax
from jax_cfd.base import test_util
from jax_cfd.ml import towers  # pylint: disable=unused-import


# Tower factories and input scale functions under test, spelled as the gin
# names that the test body binds with `@{name}` references.
TOWERS = [
    'towers.forward_tower_factory',
    'towers.residual_block_tower_factory',
]
SCALE_FNS = [
    'towers.fixed_scale',
    'towers.scale_to_range',
]
# Spatial dimensionalities exercised for every tower: 1D, 2D and 3D.
NDIMS = list(range(1, 4))
# Number of channels in the last axis of the test inputs.
INPUT_CHANNELS = [1, 6]


def test_parameters():
  """Builds the named parameter sets for `TowersTest.test_output_shapes`.

  One parameter dict is produced for every combination of tower factory,
  input scale function, spatial dimensionality and input channel count.

  Returns:
    A list of dicts with keys `testcase_name`, `tower_module`,
    `scale_fn_module`, `ndim` and `input_channels`, suitable for
    `parameterized.named_parameters`.
  """
  def _case_name(tower, scale_fn, ndim, input_channels):
    # unittest resolves test names by splitting on '.', so dotted gin names
    # such as 'towers.fixed_scale' would make individual cases impossible to
    # select from the command line; use '_' as the separator instead.
    parts = [tower, scale_fn, f'{ndim}D', f'{input_channels}_channels']
    return '_'.join(parts).replace('.', '_')

  return [
      dict(
          testcase_name=_case_name(tower, scale_fn, ndim, input_channels),
          tower_module=tower,
          scale_fn_module=scale_fn,
          ndim=ndim,
          input_channels=input_channels,
      )
      for tower, scale_fn, ndim, input_channels in itertools.product(
          TOWERS, SCALE_FNS, NDIMS, INPUT_CHANNELS)
  ]


# This is a parameter generator, not a test: without this flag pytest would
# collect it because of its `test_` prefix and call it, and the non-None
# return value triggers PytestReturnNotNoneWarning.
test_parameters.__test__ = False


@gin.configurable
def forward_pass_module(
    num_output_channels,
    ndim,
    tower_module=gin.REQUIRED
):
  """Returns a function that builds the configured tower and runs it.

  The tower is instantiated inside the returned function on every call. In
  this file the returned function is wrapped with `hk.transform`.

  Args:
    num_output_channels: first argument forwarded to `tower_module`.
    ndim: second argument forwarded to `tower_module`.
    tower_module: tower factory, invoked as
      `tower_module(num_output_channels, ndim)`; supplied via gin.

  Returns:
    A one-argument function mapping `inputs` to the tower's output.
  """
  def apply_tower(inputs):
    tower = tower_module(num_output_channels, ndim)
    return tower(inputs)

  return apply_tower


class TowersTest(test_util.TestCase):
  """Tests towers construction, configuration and composition."""

  def setUp(self):
    """Configures all scale_fns that have gin.REQUIRED values."""
    super().setUp()
    gin.enter_interactive_mode()
    # gin bindings live in process-global state. Clear them after every test
    # so the bindings parsed here and in the test body cannot leak into later
    # tests or other test modules running in the same process. Registering
    # the cleanup before parsing keeps it effective even if parsing fails.
    self.addCleanup(gin.clear_config)
    config = '\n'.join([
        'towers.fixed_scale.rescaled_one = 0.3',
        'towers.scale_to_range.min_value = -1.23',
        'towers.scale_to_range.max_value = 1.21'
    ])
    gin.parse_config(config)

  @parameterized.named_parameters(*test_parameters())
  def test_output_shapes(
      self,
      tower_module,
      scale_fn_module,
      ndim,
      input_channels
  ):
    """Tests that towers produce outputs of expected shapes.

    Args:
      tower_module: gin name of the tower factory under test.
      scale_fn_module: gin name of the function bound to the tower's
        `inputs_scale_fn` parameter.
      ndim: number of spatial dimensions of the generated inputs.
      input_channels: size of the last (channel) axis of the inputs.
    """
    # setUp has already entered gin interactive mode; only the per-case
    # bindings are added here.
    config = '\n'.join([
        f'forward_pass_module.tower_module = @{tower_module}',
        f'{tower_module}.inputs_scale_fn = @{scale_fn_module}'
    ])
    gin.parse_config(config)

    num_output_channels = 5
    spatial_size = 17
    # Draw the inputs and initialize the parameters from independent keys
    # instead of reusing a single key for both.
    input_rng, init_rng = jax.random.split(jax.random.PRNGKey(42))
    inputs = jax.random.uniform(
        input_rng, (spatial_size,) * ndim + (input_channels,))

    forward_pass = hk.without_apply_rng(
        hk.transform(forward_pass_module(num_output_channels, ndim)))
    params = forward_pass.init(init_rng, inputs)
    output = forward_pass.apply(params, inputs)
    # Spatial axes must be preserved; only the channel axis changes size.
    expected_output_shape = inputs.shape[:-1] + (num_output_channels,)
    self.assertEqual(output.shape, expected_output_shape)


if __name__ == '__main__':
  # Executed as a script rather than imported: hand control to absl's test
  # runner for the test cases defined in this module.
  absltest.main()