"""Report which accelerator backend and devices JAX can use.

Prints JAX's default backend name, then the devices JAX sees on that
backend. PyMC is imported in the same process; NOTE(review): presumably
this confirms PyMC loads alongside JAX (e.g. before JAX-based sampling) —
confirm intent.
"""
import jax
import pymc as pm

# Platform name JAX dispatches to by default (e.g. "cpu" or "gpu").
backend_name = jax.default_backend()
print(backend_name)

# Devices visible to JAX; with no argument this is the default backend's list.
visible_devices = jax.devices()
print(visible_devices)