# optimizer_modules.py (518 bytes) — from jax-cfd, committed by mashun1.
"""Configurable optimizers from JAX."""
import gin
from jax.example_libraries import optimizers


@gin.configurable
def optimizer(value):
  """Hands back whatever object gin bound to `value`, untouched.

  This is an identity function whose only job is to be gin-configurable.
  Presumably a gin config selects the optimizer by binding this parameter
  (e.g. ``optimizer.value = @adam()``), and callers retrieve it by calling
  ``optimizer()`` — confirm against the project's gin files.

  Args:
    value: The object to pass through.

  Returns:
    `value`, exactly as received.
  """
  chosen = value
  return chosen


# Register selected callables from `jax.example_libraries.optimizers` with gin
# so that configs can refer to them by name. The tuple order is the
# registration order.
for _configurable_fn in (
    # Optimizer constructors.
    optimizers.adam,
    optimizers.momentum,
    optimizers.nesterov,
    # Decay / schedule helpers.
    optimizers.exponential_decay,
    optimizers.inverse_time_decay,
    optimizers.polynomial_decay,
    optimizers.piecewise_constant,
):
  gin.external_configurable(_configurable_fn)
del _configurable_fn  # Don't leak the loop variable into the module namespace.