physics_specifications.py 1.27 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""Modules with PhysicsSpecifications for various equations.

To ensure that all components of the pipeline obtain the expected PhysicsSpecs,
all modules (except those specialized for a particular equation) must
instantiate PhysicsSpecs objects using `get_physics_specs`, which should be
configured appropriately via gin.
"""

import dataclasses
from typing import Optional

import gin
from jax_cfd.ml import forcings


# Module-level alias for `forcings.ForcingModule`, used in the field
# annotations of the specs classes below.
ForcingModule = forcings.ForcingModule


@gin.configurable
def get_physics_specs(physics_specs_cls=gin.REQUIRED):
  """Builds the physics specs object selected through gin.

  Args:
    physics_specs_cls: callable invoked with no arguments to produce the specs
      object. It defaults to `gin.REQUIRED`, so a gin binding has to provide
      it unless the caller passes it explicitly.

  Returns:
    Whatever `physics_specs_cls()` returns.
  """
  specs = physics_specs_cls()
  return specs


@gin.register
@dataclasses.dataclass
class BasePhysicsSpecs:
  """Base class for keeping physical parameters and forcing module.

  Subclasses inherit `forcing_module` as their first dataclass field and add
  their own equation-specific parameters after it.

  Attributes:
    forcing_module: a `ForcingModule` (alias of `forcings.ForcingModule`) or
      `None`. No default is declared, so a value has to be supplied when the
      dataclass is constructed.
  """
  forcing_module: Optional[ForcingModule]


@gin.register
@dataclasses.dataclass
class KsPhysicsSpecs(BasePhysicsSpecs):
  """Configurable physical parameters for Kuramoto-Sivashinsky models.

  Declares no fields of its own; the only field is `forcing_module`,
  inherited from `BasePhysicsSpecs`.
  """


@gin.register
@dataclasses.dataclass
class NavierStokesPhysicsSpecs(BasePhysicsSpecs):
  """Configurable physical parameters and modules for Navier-Stokes models.

  Fields follow the inherited `forcing_module` in the generated constructor.
  None of them has a default, so all must be supplied at construction.

  Attributes:
    density: float density parameter (units are not specified in this file).
    viscosity: float viscosity parameter (units are not specified in this
      file).
  """
  density: float
  viscosity: float


# NOTE(review): this class is decorated with `@gin.configurable`, whereas the
# other specs classes in this file use `@gin.register` -- confirm whether the
# difference is intentional.
@gin.configurable
@dataclasses.dataclass
class SpectralNavierStokesPhysicsSpecs(BasePhysicsSpecs):
  """Configurable physical parameters for spectral Navier-Stokes models.

  Fields follow the inherited `forcing_module` in the generated constructor.
  None of them has a default, so all must be supplied at construction.

  Attributes:
    viscosity: float viscosity parameter (units are not specified in this
      file).
    drag: float drag parameter; presumably a damping coefficient -- TODO
      confirm against the code that consumes these specs.
    smooth: boolean flag; presumably toggles some smoothing or filtering step
      in the spectral solver -- TODO confirm against the consuming code.
  """
  viscosity: float
  drag: float
  smooth: bool