jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.5
optax>=0.0.8
torch==1.11.0+cpu
-f https://download.pytorch.org/whl/torch_stable.html
torchvision==0.12.0+cpu