datasets<4.0.0 flax>=0.7.1 nltk>=3.8.2 optax