# test_gpu.py (committed by wangsen)
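import jax
import pymc as pm  # unused below; importing it checks that PyMC is installed alongside JAX

# Name of the backend JAX picked by default: "gpu" when a supported GPU is detected, otherwise "cpu" (or "tpu").
print(jax.default_backend())
# Devices available on that default backend.
print(jax.devices())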