## How to install relevant libraries
In the following, we explain how to install all relevant libraries on your local computer and on a TPU VM.
It is recommended to install all relevant libraries both on your local machine
and on the TPU virtual machine. This way, quick prototyping and testing can be done on
your local machine and the actual training can be done on the TPU VM.
### Local computer
The following libraries are required to train a JAX/Flax model with 🤗 Transformers and 🤗 Datasets:
- [JAX](https://github.com/google/jax/)
- [Flax](https://github.com/google/flax)
- [Optax](https://github.com/deepmind/optax)
- [Transformers](https://github.com/huggingface/transformers)
- [Datasets](https://github.com/huggingface/datasets)

You should install the above libraries in a [virtual environment](https://docs.python.org/3/library/venv.html).
If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). Create a virtual environment with the version of Python you're going
to use and activate it.
You should be able to run the command:
```bash
python3 -m venv <your-venv-name>
```
You can activate your venv by running
```bash
source ~/<your-venv-name>/bin/activate
```
We strongly recommend making use of the provided JAX/Flax example scripts in [transformers/examples/flax](https://github.com/huggingface/transformers/tree/master/examples/flax), even if you want to train a JAX/Flax model from another GitHub repository that is not integrated into 🤗 Transformers.
In all likelihood, you will need to adapt one of the example scripts, so we recommend forking and cloning the 🤗 Transformers repository as follows.
Doing so will allow you to share your fork of the Transformers library with your team members so that the team effectively works on the same code base. It will also automatically install the newest versions of `flax`, `jax` and `optax`.
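To illustrate what sharing the same code base can look like in practice, here is a rough sketch of how a team member could work on your project branch once the steps below have been completed. The remote name `team`, the `<team-member-handle>` placeholder and the branch name are illustrative, not part of the official setup:

```bash
# add a teammate's fork as an additional remote and work on the shared project branch
git remote add team https://github.com/<team-member-handle>/transformers.git
git fetch team
git checkout -b a-descriptive-name-for-my-project team/a-descriptive-name-for-my-project
```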
**IMPORTANT**: If you are setting up your environment on a TPU VM, make sure to
install JAX's TPU version before cloning and installing the transformers repository.
Otherwise, an incorrect version of JAX will be installed, and the following commands will
throw an error.
The command for installing JAX's TPU version is given in the TPU VM section below.
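To make this ordering concrete, a TPU VM setup respecting the note above could look roughly like the following sketch; it only combines commands that appear later in this document (`<your Github handle>` is a placeholder):

```bash
# 1. install JAX's TPU build first (see the TPU VM section below)
pip install requests
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# 2. only then clone your fork and install transformers in editable mode
git clone https://github.com/<your Github handle>/transformers.git
cd transformers
pip install -e ".[flax]"
```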
1. Fork the [repository](https://github.com/huggingface/transformers) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote:
```bash
$ git clone https://github.com/<your Github handle>/transformers.git
$ cd transformers
$ git remote add upstream https://github.com/huggingface/transformers.git
```
3. Create a new branch to hold your development changes. This is especially useful to share code changes with your team:
```bash
$ git checkout -b a-descriptive-name-for-my-project
```
4. Set up a flax environment by running the following command in a virtual environment:
```bash
$ pip install -e ".[flax]"
```
(If transformers was already installed in the virtual environment, remove
it with `pip uninstall transformers` before reinstalling it in editable
mode with the `-e` flag.)
If you have already cloned the repo, you might need to `git pull` to get the most recent changes.
Running this command will automatically install `flax`, `jax` and `optax`.
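As an optional sanity check (not part of the official steps), you can verify inside the venv that the editable install and these dependencies are picked up; the printed versions will of course vary:

```bash
# should print the locally installed versions of transformers, flax, jax and optax
python -c "import transformers, flax, jax, optax; print(transformers.__version__, flax.__version__, jax.__version__, optax.__version__)"
python -c "import transformers; print(transformers.__file__)"  # should point into your local clone
```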
Next, you should also install the 🤗 Datasets library. We strongly recommend installing the
library from source to benefit from the most current additions during the community week.
Simply run the following steps:
```
$ cd ~/
$ git clone https://github.com/huggingface/datasets.git
$ cd datasets
$ pip install -e ".[streaming]"
```
If you plan on contributing a specific dataset during
the community week, please fork the datasets repository and follow the instructions
[here](https://github.com/huggingface/datasets/blob/master/CONTRIBUTING.md#how-to-create-a-pull-request).
To verify that all libraries are correctly installed, you can run the following command.
It assumes that both `transformers` and `datasets` were installed from master - otherwise
datasets streaming will not work correctly.
```python
from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset
import jax
dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
dummy_input = next(iter(dataset))["text"]
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]
model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
model(input_ids)
```
### TPU VM
**VERY IMPORTANT** - Only one process can access the TPU cores at a time. This means that if multiple team members
try to connect to the TPU cores at the same time, errors such as:
```
libtpu.so already in used by another process. Not attempting to load libtpu.so in this process.
```
are thrown. Consequently, we recommend that every team member create their own virtual environment, but only one
person should run the heavy training processes. Also, please take turns when setting up the TPUv3-8 so that everybody
can verify that JAX is correctly installed.
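Once JAX's TPU version is installed (see below), a rough way to check whether the TPU is currently free before you start working is to initialize JAX in your own venv. This is only a sketch, and the exact behaviour depends on the JAX/libtpu version:

```python
# if another process already holds the TPU, JAX typically prints the
# `libtpu.so already in use` warning shown above and no TPU devices show up
import jax

print(jax.device_count())  # should be 8 on a free TPUv3-8
print(jax.devices())       # should list TPU devices, not only CPU
```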
The following libraries are required to train a JAX/Flax model with 🤗 Transformers and 🤗 Datasets on TPU VM:
- [JAX](https://github.com/google/jax/)
- [Flax](https://github.com/google/flax)
- [Optax](https://github.com/deepmind/optax)
- [Transformers](https://github.com/huggingface/transformers)
- [Datasets](https://github.com/huggingface/datasets)
You should install the above libraries in a [virtual environment](https://docs.python.org/3/library/venv.html).
If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). Create a virtual environment with the version of Python you're going
to use and activate it.
You should be able to run the command:
```bash
python3 -m venv <your-venv-name>
```
If this doesn't work, you might first have to install `python3-venv`. You can do this as follows:
```bash
sudo apt-get install python3-venv
```
You can activate your venv by running
```bash
source ~/<your-venv-name>/bin/activate
```
Next, you should install JAX's TPU version by running the following commands:
```
$ pip install requests
```

and then:

```
$ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
**Note**: Running this command might actually throw an error, such as:
```
Building wheel for jax (setup.py) ... error
ERROR: Command errored out with exit status 1:
command: /home/patrick/patrick/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"'; __file__='"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-pydotzlo
cwd: /tmp/pip-install-lwseckn1/jax/
Complete output (6 lines):
usage: setup.py [global_opts] cmd1 [cmd1_opts] [cmd2 [cmd2_opts] ...]
or: setup.py --help [cmd1 cmd2 ...]
or: setup.py --help-commands
or: setup.py cmd --help
error: invalid command 'bdist_wheel'
----------------------------------------
ERROR: Failed building wheel for jax
```
JAX should nevertheless have been installed correctly.
To verify that JAX was correctly installed, you can run the following command:
```python
import jax

jax.device_count()
```
This should display the number of TPU cores, which should be 8 on a TPUv3-8 VM.
Now you can run the following steps as usual.
We strongly recommend making use of the provided JAX/Flax example scripts in [transformers/examples/flax](https://github.com/huggingface/transformers/tree/master/examples/flax), even if you want to train a JAX/Flax model from another GitHub repository that is not integrated into 🤗 Transformers.
In all likelihood, you will need to adapt one of the example scripts, so we recommend forking and cloning the 🤗 Transformers repository as follows.
Doing so will allow you to share your fork of the Transformers library with your team members so that the team effectively works on the same code base. It will also automatically install the newest versions of `flax`, `jax` and `optax`.
1. Fork the [repository](https://github.com/huggingface/transformers) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.

From here on, the remaining steps are the same as for the local computer: clone your fork, create a
development branch, install `transformers` in editable mode with the `[flax]` extra, install 🤗 Datasets
from source, and run the same verification snippet as above, which loads
`FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")` and runs a forward pass with `model(input_ids)`.
## Quickstart flax and jax
[JAX](https://jax.readthedocs.io/en/latest/index.html) is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more. A great place for getting started with JAX is the [JAX 101 Tutorial](https://jax.readthedocs.io/en/latest/jax-101/index.html).
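As a minimal illustration of these composable transformations (this snippet is only a sketch and not taken from the tutorial):

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((w * x) ** 2)

grad_fn = jax.jit(jax.grad(loss))                    # differentiate w.r.t. `w`, then JIT-compile
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0))  # vectorize over a batch of inputs

w = jnp.array([1.0, 2.0])
xs = jnp.ones((4, 2))
print(batched_grad(w, xs).shape)  # (4, 2)
```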