Unverified Commit 87918d32 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub

[examples/Flax] add a section about GPUs (#15198)



* add a section about GPUs

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent b8810847
@@ -51,6 +51,15 @@ Consider applying for the [Google TPU Research Cloud project](https://sites.rese
Each example README contains more details on the specific model and training
procedure.
## Running on single or multiple GPUs

All of our JAX/Flax examples also run efficiently on single and multiple GPUs. You can use the same instructions in the README to launch training on GPU.
Distributed training is supported out of the box, and the scripts will use all the GPUs that are detected.

You should follow this [guide for installing JAX on GPUs](https://github.com/google/jax/#pip-installation-gpu-cuda), since the installation depends on
your CUDA and cuDNN version.
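Once a CUDA-enabled `jaxlib` is installed, JAX discovers all local accelerators automatically; the example scripts pick them up without extra flags. A minimal sketch to verify what JAX sees (on a CPU-only machine this falls back to a single CPU device):

```python
import jax

# All devices JAX can see; on a multi-GPU machine this lists every GPU.
print(jax.devices())

# Number of devices on this host -- the scripts shard batches across these.
print(jax.local_device_count())

# Reports "gpu" when a CUDA-enabled jaxlib is installed, "cpu" otherwise.
print(jax.default_backend())
```

If `jax.default_backend()` prints `cpu` on a GPU machine, the CPU-only `jaxlib` wheel is installed and you should reinstall following the GPU guide linked above.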
## Supported models

Porting models from PyTorch to JAX/Flax is an ongoing effort.