nonlinearities.py 400 Bytes
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
"""Registry of nonlinearities that can be used in neural networks."""

import gin
import jax
import jax.numpy as jnp


# Each nonlinearity is registered with gin under an explicit name equal to its
# attribute name in this module, so gin configs can reference it as e.g.
# `@relu` or `@swish`. Without `name=`, gin falls back to the callable's
# `__name__`. That is fragile for aliases: in current JAX releases
# `jax.nn.swish` is an alias of `jax.nn.silu`, whose `__name__` is "silu".
relu = gin.external_configurable(jax.nn.relu, name='relu')
tanh = gin.external_configurable(jnp.tanh, name='tanh')
softplus = gin.external_configurable(jax.nn.softplus, name='softplus')
swish = gin.external_configurable(jax.nn.swish, name='swish')
# Also register under "silu": configs written against the previous
# implicit-name registration of `jax.nn.swish` may refer to `@silu`.
silu = gin.external_configurable(jax.nn.silu, name='silu')
elu = gin.external_configurable(jax.nn.elu, name='elu')
gelu = gin.external_configurable(jax.nn.gelu, name='gelu')