In the following, we will explain how to install all relevant libraries on your local computer and on a TPU VM.
It is recommended to install all relevant libraries both on your local machine
and on the TPU virtual machine. This way, quick prototyping and testing can be done on
your local machine and the actual training can be done on the TPU VM.
### Local computer
The following libraries are required to train a JAX/Flax model with 🤗 Transformers and 🤗 Datasets:

- [JAX](https://github.com/google/jax/)
...
You should install the above libraries in a [virtual environment](https://docs.python.org/3/library/venv.html).
If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). Create a virtual environment with the version of Python you're going
to use and activate it.
You should be able to run the command:
```bash
python3 -m venv <your-venv-name>
```
You can activate your venv by running
```bash
source ~/<your-venv-name>/bin/activate
```
We strongly recommend making use of the provided JAX/Flax example scripts in [transformers/examples/flax](https://github.com/huggingface/transformers/tree/master/examples/flax), even if you want to train a JAX/Flax model from another GitHub repository that is not integrated into 🤗 Transformers.
In all likelihood, you will need to adapt one of the example scripts, so we recommend forking and cloning the 🤗 Transformers repository as follows.
Doing so will allow you to share your fork of the Transformers library with your team members so that the team effectively works on the same code base. It will also automatically install the newest versions of `flax`, `jax` and `optax`.
**IMPORTANT**: If you are setting up your environment on a TPU VM, make sure to
install JAX's TPU version before cloning and installing the transformers repository.
Otherwise, an incorrect version of JAX will be installed, and the following commands will
throw an error.
To install JAX's TPU version first run the following command:
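A sketch of what this command typically looks like is given below; the version pin and wheel index URL are assumptions, so double-check the official JAX installation instructions for the current command.

```bash
# Install the TPU-enabled JAX build (version pin and wheel index are assumptions;
# see the official JAX installation docs for the up-to-date command)
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```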
1. Fork the [repository](https://github.com/huggingface/transformers) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.

2. Clone your fork to your local disk, and add the base repository as a remote:
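A sketch of what this usually looks like is shown below; `<your-github-handle>` is a placeholder, and the editable install at the end is what pulls in the `flax`, `jax` and `optax` versions mentioned above (the guide's exact commands may differ):

```bash
# Clone your fork and register the original repository as the `upstream` remote
git clone https://github.com/<your-github-handle>/transformers.git
cd transformers
git remote add upstream https://github.com/huggingface/transformers.git

# Editable install with the Flax extras (also installs flax, jax and optax)
pip install -e ".[flax]"
```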
### TPU VM

You should install the above libraries in a [virtual environment](https://docs.python.org/3/library/venv.html).
If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). Create a virtual environment with the version of Python you're going
to use and activate it.
You should be able to run the command:
```bash
python3 -m venv <your-venv-name>
```
If this doesn't work, you might first have to install `python3-venv`. You can do this as follows:
```bash
sudo apt-get install python3-venv
```
You can activate your venv by running
```bash
source ~/<your-venv-name>/bin/activate
```
Next, you should install JAX's TPU version on the TPU VM by running the following command:
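As in the note above, the following is only a sketch of the install command (the version pin is an assumption; check the official JAX installation instructions for the current one):

```bash
# TPU-enabled JAX build, same sketch as in the note above (assumed version pin)
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```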
JAX should now be installed correctly.
To verify that JAX was correctly installed, you can run the following command:

```python
import jax
jax.device_count()
```
This should display the number of TPU cores, which should be 8 on a TPUv3-8 VM.
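If you prefer to run this check directly from the shell, an equivalent one-liner (a convenience suggestion, not part of the original instructions) is:

```bash
# Print the number of visible TPU cores; expect 8 on a TPUv3-8 VM
python3 -c "import jax; print(jax.device_count())"
```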
Now you can run the following steps as usual.
We strongly recommend making use of the provided JAX/Flax example scripts in [transformers/examples/flax](https://github.com/huggingface/transformers/tree/master/examples/flax), even if you want to train a JAX/Flax model from another GitHub repository that is not integrated into 🤗 Transformers.
In all likelihood, you will need to adapt one of the example scripts, so we recommend forking and cloning the 🤗 Transformers repository as follows.
Doing so will allow you to share your fork of the Transformers library with your team members so that the team effectively works on the same code base. It will also automatically install the newest versions of `flax`, `jax` and `optax`.
1. Fork the [repository](https://github.com/huggingface/transformers) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.
...
```python
# (the earlier part of this snippet, which defines `input_ids`, is omitted above)
model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
model(input_ids)
```
## Quickstart flax and jax
[JAX](https://jax.readthedocs.io/en/latest/index.html) is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more. A great place for getting started with JAX is the [JAX 101 Tutorial](https://jax.readthedocs.io/en/latest/jax-101/index.html).
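As a minimal sketch of those transformations (this example is not from the original guide): `jax.grad` differentiates a plain NumPy-style Python function, `jax.vmap` vectorizes it over a batch, and `jax.jit` compiles it with XLA.

```python
import jax
import jax.numpy as jnp

# A plain Python + NumPy-style function
def loss(w, x):
    return jnp.sum((w * x - 1.0) ** 2)

grad_loss = jax.jit(jax.grad(loss))               # compiled gradient with respect to `w`
batched_loss = jax.vmap(loss, in_axes=(None, 0))  # map `loss` over a batch of `x`

w = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([[1.0, 1.0, 1.0],
               [2.0, 2.0, 2.0]])

print(grad_loss(w, x[0]))   # [0. 2. 4.]
print(batched_loss(w, x))   # [ 5. 35.]
```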