absl-py>=0.12.0
aqtp!=0.1.1  # Excludes 0.1.1; see https://github.com/google/aqt/issues/196
chex>=0.0.7
clu>=0.0.3  # Author note: 0.0.9 (presumably the version last tested; confirm)
einops>=0.3.0
flax>=0.6.4 # Author note: flax 0.8.4, optax 0.2.2, orbax-checkpoint==0.4.1 (presumably the versions last tested; confirm)
# Disabled requirement, kept for reference (not installed by this file):
# git+https://github.com/google/flaxformer

# The CUDA build of jaxlib selected by the cuda11_cudnn86 extra is served from
# the JAX release page below rather than from PyPI. --find-links is a global
# requirements-file option and must sit on its own line: pip only applies
# per-requirement options (e.g. --hash) that trail a requirement specifier.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda11_cudnn86]==0.4.23

ml-collections>=0.1.0
numpy>=1.19.5
pandas>=1.1.0
# tensorflow-cpu is used so that all GPU memory stays available to JAX.
# Author note: tensorflow-cpu==2.14.0 (presumably the version last tested; confirm).
tensorflow-cpu>=2.4.0
tensorflow-datasets>=4.0.1
tensorflow-probability>=0.11.1
tensorflow-text>=2.9.0