"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6d4306160ab22c54a677e07f88c7d8808b137d38"
Unverified commit 7db2a79b, authored by Suraj Patil, committed by GitHub

[examples/flax] use Repository API for push_to_hub (#13672)

* use Repository for push_to_hub

* update readme

* update other flax scripts

* update readme

* update qa example

* fix push_to_hub call

* fix typo

* fix more typos

* update readme

* use absolute path to get repo name

* fix glue script
parent b90096fe
@@ -61,3 +61,14 @@ For a complete overview of models that are supported in JAX/Flax, please have a
 Over 3000 pretrained checkpoints are supported in JAX/Flax as of May 2021.
 Click [here](https://huggingface.co/models?filter=jax) to see the full list on the 🤗 hub.
+
+## Upload the trained/fine-tuned model to the Hub
+
+All the example scripts support automatic upload of your final model to the [Model Hub](https://huggingface.co/models) by adding a `--push_to_hub` argument. It will then create a repository with your username slash the name of the folder you are using as `output_dir`. For instance, `"sgugger/test-mrpc"` if your username is `sgugger` and you are working in the folder `~/tmp/test-mrpc`.
+
+To specify a given repository name, use the `--hub_model_id` argument. You will need to specify the whole repository name (including your username), for instance `--hub_model_id sgugger/finetuned-bert-mrpc`. To upload to an organization you are a member of, just use the name of that organization instead of your username: `--hub_model_id huggingface/finetuned-bert-mrpc`.
+
+A few notes on this integration:
+
+- you will need to be logged in to the Hugging Face website locally for it to work, the easiest way to achieve this is to run `huggingface-cli login` and then type your username and password when prompted. You can also pass along your authentication token with the `--hub_token` argument.
+- the `output_dir` you pick will either need to be a new folder or a local clone of the distant repository you are using.
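As a concrete illustration of the naming rule the added README text describes (this sketch is editorial, not part of the diff): the example scripts derive the default repo id from the *name* of the `output_dir` folder via `get_full_repo_name`, which prefixes your Hub username, or an organization if one is passed. Running it requires a cached `huggingface-cli login` or an explicit `token=` argument.

```python
from pathlib import Path

from transformers.file_utils import get_full_repo_name

# Only the final folder name matters, not the full path.
folder = Path("~/tmp/test-mrpc").expanduser().absolute().name  # -> "test-mrpc"

# Default behaviour: prefix with your username, e.g. "sgugger/test-mrpc".
print(get_full_repo_name(folder))

# With an organization the prefix changes, e.g. "huggingface/finetuned-bert-mrpc";
# the scripts instead expect you to spell the full id out via --hub_model_id.
print(get_full_repo_name("finetuned-bert-mrpc", organization="huggingface"))
```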
@@ -33,32 +33,10 @@ in Norwegian on a single TPUv3-8 pod.
 The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
-Let's start by creating a model repository to save the trained model and logs.
-Here we call the model `"norwegian-roberta-base"`, but you can change the model name as you like.
-You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
-you are logged in) or via the command line:
-```
-huggingface-cli repo create norwegian-roberta-base
-```
-Next we clone the model repository to add the tokenizer and model files.
-```
-git clone https://huggingface.co/<your-username>/norwegian-roberta-base
-```
-To setup all relevant files for training, let's go into the cloned model directory.
+To setup all relevant files for training, let's create a directory.
 ```bash
-cd norwegian-roberta-base
+mkdir ./norwegian-roberta-base
 ```
-Next, let's add a symbolic link to the `run_mlm_flax.py`.
-```bash
-ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
-```
 ### Train tokenizer
@@ -92,7 +70,7 @@ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=
 ])
 # Save files to disk
-tokenizer.save("./tokenizer.json")
+tokenizer.save("./norwegian-roberta-base/tokenizer.json")
 ```
 ### Create configuration
@@ -105,7 +83,7 @@ in the local model folder:
 from transformers import RobertaConfig
 config = RobertaConfig.from_pretrained("roberta-base", vocab_size=50265)
-config.save_pretrained("./")
+config.save_pretrained("./norwegian-roberta-base")
 ```
 Great, we have set up our model repository. During training, we will automatically
@@ -116,11 +94,11 @@ push the training logs and model weights to the repo.
 Next we can run the example script to pretrain the model:
 ```bash
-./run_mlm_flax.py \
-    --output_dir="./" \
+python run_mlm_flax.py \
+    --output_dir="./norwegian-roberta-base" \
     --model_type="roberta" \
-    --config_name="./" \
-    --tokenizer_name="./" \
+    --config_name="./norwegian-roberta-base" \
+    --tokenizer_name="./norwegian-roberta-base" \
     --dataset_name="oscar" \
     --dataset_config_name="unshuffled_deduplicated_no" \
     --max_seq_length="128" \
@@ -157,32 +135,11 @@ in Norwegian on a single TPUv3-8 pod.
 The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
-Let's start by creating a model repository to save the trained model and logs.
-Here we call the model `"norwegian-gpt2"`, but you can change the model name as you like.
-You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
-you are logged in) or via the command line:
-```
-huggingface-cli repo create norwegian-gpt2
-```
-Next we clone the model repository to add the tokenizer and model files.
-```
-git clone https://huggingface.co/<your-username>/norwegian-gpt2
-```
-To setup all relevant files for training, let's go into the cloned model directory.
-```bash
-cd norwegian-gpt2
-```
-Next, let's add a symbolic link to the training script `run_clm_flax.py`.
+To setup all relevant files for training, let's create a directory.
 ```bash
-ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py
+mkdir ./norwegian-gpt2
 ```
 ### Train tokenizer
@@ -216,7 +173,7 @@ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50257, min_frequency=
 ])
 # Save files to disk
-tokenizer.save("./tokenizer.json")
+tokenizer.save("./norwegian-gpt2/tokenizer.json")
 ```
 ### Create configuration
@@ -229,7 +186,7 @@ in the local model folder:
 from transformers import GPT2Config
 config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50257)
-config.save_pretrained("./")
+config.save_pretrained("./norwegian-gpt2")
 ```
 Great, we have set up our model repository. During training, we will now automatically
@@ -240,11 +197,11 @@ push the training logs and model weights to the repo.
 Finally, we can run the example script to pretrain the model:
 ```bash
-./run_clm_flax.py \
-    --output_dir="./" \
+python run_clm_flax.py \
+    --output_dir="./norwegian-gpt2" \
     --model_type="gpt2" \
-    --config_name="./" \
-    --tokenizer_name="./" \
+    --config_name="./norwegian-gpt2" \
+    --tokenizer_name="./norwegian-gpt2" \
     --dataset_name="oscar" \
     --dataset_config_name="unshuffled_deduplicated_no" \
     --do_train --do_eval \
@@ -282,30 +239,10 @@ The example script uses the 🤗 Datasets library. You can easily customize them
 Let's start by creating a model repository to save the trained model and logs.
 Here we call the model `"norwegian-t5-base"`, but you can change the model name as you like.
-You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
-you are logged in) or via the command line:
-```
-huggingface-cli repo create norwegian-t5-base
-```
-Next we clone the model repository to add the tokenizer and model files.
-```
-git clone https://huggingface.co/<your-username>/norwegian-t5-base
-```
-To setup all relevant files for trairing, let's go into the cloned model directory.
-```bash
-cd norwegian-t5-base
-```
-Next, let's add a symbolic link to the `run_t5_mlm_flax.py` and `t5_tokenizer_model` scripts.
+To setup all relevant files for training, let's create a directory.
 ```bash
-ln -s ~/transformers/examples/flax/language-modeling/run_t5_mlm_flax.py run_t5_mlm_flax.py
-ln -s ~/transformers/examples/flax/language-modeling/t5_tokenizer_model.py t5_tokenizer_model.py
+cd ./norwegian-t5-base
 ```
 ### Train tokenizer
@@ -351,7 +288,7 @@ tokenizer.train_from_iterator(
 )
 # Save files to disk
-tokenizer.save("./tokenizer.json")
+tokenizer.save("./norwegian-t5-base/tokenizer.json")
 ```
 ### Create configuration
@@ -364,7 +301,7 @@ in the local model folder:
 from transformers import T5Config
 config = T5Config.from_pretrained("google/t5-v1_1-base", vocab_size=tokenizer.get_vocab_size())
-config.save_pretrained("./")
+config.save_pretrained("./norwegian-t5-base")
 ```
 Great, we have set up our model repository. During training, we will automatically
@@ -375,11 +312,11 @@ push the training logs and model weights to the repo.
 Next we can run the example script to pretrain the model:
 ```bash
-./run_t5_mlm_flax.py \
-    --output_dir="./" \
+python run_t5_mlm_flax.py \
+    --output_dir="./norwegian-t5-base" \
     --model_type="t5" \
-    --config_name="./" \
-    --tokenizer_name="./" \
+    --config_name="./norwegian-t5-base" \
+    --tokenizer_name="./norwegian-t5-base" \
     --dataset_name="oscar" \
     --dataset_config_name="unshuffled_deduplicated_no" \
     --max_seq_length="512" \
...
@@ -43,6 +43,7 @@ from flax import jax_utils, traverse_util
 from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+from huggingface_hub import Repository
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -54,6 +55,7 @@ from transformers import (
     is_tensorboard_available,
     set_seed,
 )
+from transformers.file_utils import get_full_repo_name
 from transformers.testing_utils import CaptureLogger
@@ -275,6 +277,16 @@ def main():
     # Set seed before initializing model.
     set_seed(training_args.seed)

+    # Handle the repository creation
+    if training_args.push_to_hub:
+        if training_args.hub_model_id is None:
+            repo_name = get_full_repo_name(
+                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+            )
+        else:
+            repo_name = training_args.hub_model_id
+        repo = Repository(training_args.output_dir, clone_from=repo_name)
+
     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
@@ -654,12 +666,10 @@ def main():
        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(unreplicate(state.params))
-           model.save_pretrained(
-               training_args.output_dir,
-               params=params,
-               push_to_hub=training_args.push_to_hub,
-               commit_message=f"Saving weights and logs of step {cur_step}",
-           )
+           model.save_pretrained(training_args.output_dir, params=params)
+           tokenizer.save_pretrained(training_args.output_dir)
+           if training_args.push_to_hub:
+               repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

 if __name__ == "__main__":
...
@@ -41,6 +41,7 @@ import optax
 from flax import jax_utils, traverse_util
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
+from huggingface_hub import Repository
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -54,6 +55,7 @@ from transformers import (
     is_tensorboard_available,
     set_seed,
 )
+from transformers.file_utils import get_full_repo_name

 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
@@ -308,6 +310,16 @@ if __name__ == "__main__":
     # Set seed before initializing model.
     set_seed(training_args.seed)

+    # Handle the repository creation
+    if training_args.push_to_hub:
+        if training_args.hub_model_id is None:
+            repo_name = get_full_repo_name(
+                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+            )
+        else:
+            repo_name = training_args.hub_model_id
+        repo = Repository(training_args.output_dir, clone_from=repo_name)
+
     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
@@ -683,9 +695,7 @@ if __name__ == "__main__":
        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-           model.save_pretrained(
-               training_args.output_dir,
-               params=params,
-               push_to_hub=training_args.push_to_hub,
-               commit_message=f"Saving weights and logs of step {cur_step}",
-           )
+           model.save_pretrained(training_args.output_dir, params=params)
+           tokenizer.save_pretrained(training_args.output_dir)
+           if training_args.push_to_hub:
+               repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
@@ -39,6 +39,7 @@ import optax
 from flax import jax_utils, traverse_util
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
+from huggingface_hub import Repository
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -52,6 +53,7 @@ from transformers import (
     is_tensorboard_available,
     set_seed,
 )
+from transformers.file_utils import get_full_repo_name
 from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
@@ -438,6 +440,16 @@ if __name__ == "__main__":
     # Set seed before initializing model.
     set_seed(training_args.seed)

+    # Handle the repository creation
+    if training_args.push_to_hub:
+        if training_args.hub_model_id is None:
+            repo_name = get_full_repo_name(
+                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+            )
+        else:
+            repo_name = training_args.hub_model_id
+        repo = Repository(training_args.output_dir, clone_from=repo_name)
+
     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
@@ -791,9 +803,7 @@ if __name__ == "__main__":
        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-           model.save_pretrained(
-               training_args.output_dir,
-               params=params,
-               push_to_hub=training_args.push_to_hub,
-               commit_message=f"Saving weights and logs of step {cur_step}",
-           )
+           model.save_pretrained(training_args.output_dir, params=params)
+           tokenizer.save_pretrained(training_args.output_dir)
+           if training_args.push_to_hub:
+               repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
@@ -26,31 +26,6 @@ of the script.
 The following example fine-tunes BERT on SQuAD:
-To begin with it is recommended to create a model repository to save the trained model and logs.
-Here we call the model `"bert-qa-squad-test"`, but you can change the model name as you like.
-You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
-you are logged in) or via the command line:
-```
-huggingface-cli repo create bert-qa-squad-test
-```
-Next we clone the model repository to add the tokenizer and model files.
-```
-git clone https://huggingface.co/<your-username>/bert-qa-squad-test
-```
-Great, we have set up our model repository. During training, we will automatically
-push the training logs and model weights to the repo.
-Next, let's add a symbolic link to the `run_qa.py`.
-```bash
-export MODEL_DIR="./bert-qa-squad-test"
-ln -s ~/transformers/examples/flax/question-answering/run_qa.py run_qa.py
-```
 ```bash
 python run_qa.py \
@@ -63,7 +38,7 @@ python run_qa.py \
   --learning_rate 3e-5 \
   --num_train_epochs 2 \
   --per_device_train_batch_size 12 \
-  --output_dir ${MODEL_DIR} \
+  --output_dir ./bert-qa-squad \
   --eval_steps 1000 \
   --push_to_hub
 ```
@@ -101,8 +76,9 @@ python run_qa.py \
   --num_train_epochs 2 \
   --max_seq_length 384 \
   --doc_stride 128 \
-  --output_dir /tmp/wwm_uncased_finetuned_squad/ \
-  --eval_steps 1000
+  --output_dir ./wwm_uncased_finetuned_squad/ \
+  --eval_steps 1000 \
+  --push_to_hub
 ```
 Training with the previously defined hyper-parameters yields the following results:
...
@@ -25,6 +25,7 @@ import sys
 import time
 from dataclasses import dataclass, field
 from itertools import chain
+from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple

 import datasets
@@ -41,6 +42,7 @@ from flax.jax_utils import replicate, unreplicate
 from flax.metrics import tensorboard
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
+from huggingface_hub import Repository
 from transformers import (
     AutoConfig,
     AutoTokenizer,
@@ -50,6 +52,7 @@ from transformers import (
     PreTrainedTokenizerFast,
     TrainingArguments,
 )
+from transformers.file_utils import get_full_repo_name
 from transformers.utils import check_min_version
 from utils_qa import postprocess_qa_predictions
@@ -359,6 +362,16 @@ def main():
     transformers.utils.logging.set_verbosity_error()
     # endregion

+    # Handle the repository creation
+    if training_args.push_to_hub:
+        if training_args.hub_model_id is None:
+            repo_name = get_full_repo_name(
+                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+            )
+        else:
+            repo_name = training_args.hub_model_id
+        repo = Repository(training_args.output_dir, clone_from=repo_name)
+
     # region Load Data
     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
@@ -891,12 +904,10 @@ def main():
        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(unreplicate(state.params))
-           model.save_pretrained(
-               training_args.output_dir,
-               params=params,
-               push_to_hub=training_args.push_to_hub,
-               commit_message=f"Saving weights and logs of step {cur_step}",
-           )
+           model.save_pretrained(training_args.output_dir, params=params)
+           tokenizer.save_pretrained(training_args.output_dir)
+           if training_args.push_to_hub:
+               repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"

     # endregion
...
@@ -11,43 +11,12 @@ way which enables simple and efficient model parallelism.
 For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files and you also will find examples of these below.
-Let's start by creating a model repository to save the trained model and logs.
-Here we call the model `"bart-base-xsum"`, but you can change the model name as you like.
-You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
-you are logged in) or via the command line:
-```
-huggingface-cli repo create bart-base-xsum
-```
-Next we clone the model repository to add the tokenizer and model files.
-```
-git clone https://huggingface.co/<your-username>/bart-base-xsum
-```
-To ensure that all tensorboard traces will be uploaded correctly, we need to
-track them. You can run the following command inside your model repo to do so.
-```
-cd bart-base-xsum
-git lfs track "*tfevents*"
-```
-Great, we have set up our model repository. During training, we will automatically
-push the training logs and model weights to the repo.
-Next, let's add a symbolic link to the `run_summarization_flax.py`.
-```bash
-export MODEL_DIR="./bart-base-xsum"
-ln -s ~/transformers/examples/flax/summarization/run_summarization_flax.py run_summarization_flax.py
-```
 ### Train the model
 Next we can run the example script to train the model:
 ```bash
 python run_summarization_flax.py \
-    --output_dir ${MODEL_DIR} \
+    --output_dir ./bart-base-xsum \
     --model_name_or_path facebook/bart-base \
     --tokenizer_name facebook/bart-base \
     --dataset_name="xsum" \
...
@@ -42,6 +42,7 @@ from flax import jax_utils, traverse_util
 from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+from huggingface_hub import Repository
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
@@ -52,7 +53,7 @@ from transformers import (
     TrainingArguments,
     is_tensorboard_available,
 )
-from transformers.file_utils import is_offline_mode
+from transformers.file_utils import get_full_repo_name, is_offline_mode

 logger = logging.getLogger(__name__)
@@ -333,6 +334,16 @@ def main():
     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")

+    # Handle the repository creation
+    if training_args.push_to_hub:
+        if training_args.hub_model_id is None:
+            repo_name = get_full_repo_name(
+                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+            )
+        else:
+            repo_name = training_args.hub_model_id
+        repo = Repository(training_args.output_dir, clone_from=repo_name)
+
     # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
@@ -800,12 +811,10 @@ def main():
        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-           model.save_pretrained(
-               training_args.output_dir,
-               params=params,
-               push_to_hub=training_args.push_to_hub,
-               commit_message=f"Saving weights and logs of epoch {epoch+1}",
-           )
+           model.save_pretrained(training_args.output_dir, params=params)
+           tokenizer.save_pretrained(training_args.output_dir)
+           if training_args.push_to_hub:
+               repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

 if __name__ == "__main__":
...
@@ -21,47 +21,15 @@ limitations under the License.
 Based on the script [`run_flax_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/flax/text-classification/run_flax_glue.py).
 Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
-Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models).
+Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models) and can also be used for a
+dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file (the script might need some tweaks in that case,
+refer to the comments inside for help).
-To begin with it is recommended to create a model repository to save the trained model and logs.
-Here we call the model `"bert-glue-mrpc-test"`, but you can change the model name as you like.
-You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
-you are logged in) or via the command line:
-```
-huggingface-cli repo create bert-glue-mrpc-test
-```
-Next we clone the model repository to add the tokenizer and model files.
-```
-git clone https://huggingface.co/<your-username>/bert-glue-mrpc-test
-```
-To ensure that all tensorboard traces will be uploaded correctly, we need to
-track them. You can run the following command inside your model repo to do so.
-```
-cd bert-glue-mrpc-test
-git lfs track "*tfevents*"
-```
-Great, we have set up our model repository. During training, we will automatically
-push the training logs and model weights to the repo.
-Next, let's add a symbolic link to the `run_flax_glue.py`.
+GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
 ```bash
 export TASK_NAME=mrpc
-export MODEL_DIR="./bert-glue-mrpc-test"
-ln -s ~/transformers/examples/flax/text-classification/run_flax_glue.py run_flax_glue.py
-```
-GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
-```bash
 python run_flax_glue.py \
   --model_name_or_path bert-base-cased \
   --task_name ${TASK_NAME} \
@@ -69,7 +37,7 @@ python run_flax_glue.py \
   --learning_rate 2e-5 \
   --num_train_epochs 3 \
   --per_device_train_batch_size 4 \
-  --output_dir ${MODEL_DIR} \
+  --output_dir ./$TASK_NAME/ \
   --push_to_hub
 ```
...
@@ -20,6 +20,7 @@ import os
 import random
 import time
 from itertools import chain
+from pathlib import Path
 from typing import Any, Callable, Dict, Tuple

 import datasets
@@ -34,7 +35,9 @@ from flax.jax_utils import replicate, unreplicate
 from flax.metrics import tensorboard
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
+from huggingface_hub import Repository
 from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
+from transformers.file_utils import get_full_repo_name

 logger = logging.getLogger(__name__)
@@ -128,6 +131,10 @@ def parse_args():
         action="store_true",
         help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
     )
+    parser.add_argument(
+        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
+    )
+    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
     args = parser.parse_args()

     # Sanity checks
@@ -141,6 +148,9 @@ def parse_args():
            extension = args.validation_file.split(".")[-1]
            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."

+    if args.push_to_hub:
+        assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
+
     if args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)
@@ -267,6 +277,14 @@ def main():
     datasets.utils.logging.set_verbosity_error()
     transformers.utils.logging.set_verbosity_error()

+    # Handle the repository creation
+    if args.push_to_hub:
+        if args.hub_model_id is None:
+            repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token)
+        else:
+            repo_name = args.hub_model_id
+        repo = Repository(args.output_dir, clone_from=repo_name)
+
     # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
     # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
@@ -499,12 +517,10 @@ def main():
        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-           model.save_pretrained(
-               args.output_dir,
-               params=params,
-               push_to_hub=args.push_to_hub,
-               commit_message=f"Saving weights and logs of epoch {epoch}",
-           )
+           model.save_pretrained(args.output_dir, params=params)
+           tokenizer.save_pretrained(args.output_dir)
+           if args.push_to_hub:
+               repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

 if __name__ == "__main__":
...
@@ -22,31 +22,6 @@ It will either run on a datasets hosted on our hub or with your own text files f
 The following example fine-tunes BERT on CoNLL-2003:
-To begin with it is recommended to create a model repository to save the trained model and logs.
-Here we call the model `"bert-ner-conll2003-test"`, but you can change the model name as you like.
-You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
-you are logged in) or via the command line:
-```
-huggingface-cli repo create bert-ner-conll2003-test
-```
-Next we clone the model repository to add the tokenizer and model files.
-```
-git clone https://huggingface.co/<your-username>/bert-ner-conll2003-test
-```
-Great, we have set up our model repository. During training, we will automatically
-push the training logs and model weights to the repo.
-Next, let's add a symbolic link to the `run_flax_ner.py`.
-```bash
-export MODEL_DIR="./bert-ner-conll2003-test"
-ln -s ~/transformers/examples/flax/token-classification/run_flax_ner.py run_flax_ner.py
-```
 ```bash
 python run_flax_ner.py \
@@ -56,7 +31,7 @@ python run_flax_ner.py \
   --learning_rate 2e-5 \
   --num_train_epochs 3 \
   --per_device_train_batch_size 4 \
-  --output_dir ${MODEL_DIR} \
+  --output_dir ./bert-ner-conll2003 \
   --eval_steps 300 \
   --push_to_hub
 ```
...
@@ -21,6 +21,7 @@ import sys
 import time
 from dataclasses import dataclass, field
 from itertools import chain
+from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple

 import datasets
@@ -37,6 +38,7 @@ from flax.jax_utils import replicate, unreplicate
 from flax.metrics import tensorboard
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
+from huggingface_hub import Repository
 from transformers import (
     AutoConfig,
     AutoTokenizer,
@@ -44,6 +46,7 @@ from transformers import (
     HfArgumentParser,
     TrainingArguments,
 )
+from transformers.file_utils import get_full_repo_name
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
@@ -304,6 +307,16 @@ def main():
     datasets.utils.logging.set_verbosity_error()
     transformers.utils.logging.set_verbosity_error()

+    # Handle the repository creation
+    if training_args.push_to_hub:
+        if training_args.hub_model_id is None:
+            repo_name = get_full_repo_name(
+                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+            )
+        else:
+            repo_name = training_args.hub_model_id
+        repo = Repository(training_args.output_dir, clone_from=repo_name)
+
     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
     # or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
@@ -656,12 +669,10 @@ def main():
        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(unreplicate(state.params))
-           model.save_pretrained(
-               training_args.output_dir,
-               params=params,
-               push_to_hub=training_args.push_to_hub,
-               commit_message=f"Saving weights and logs of step {cur_step}",
-           )
+           model.save_pretrained(training_args.output_dir, params=params)
+           tokenizer.save_pretrained(training_args.output_dir)
+           if training_args.push_to_hub:
+               repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
...
@@ -25,37 +25,6 @@ way which enables simple and efficient model parallelism.
 In this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset.
-Let's start by creating a model repository to save the trained model and logs.
-Here we call the model `"vit-base-patch16-imagenette"`, but you can change the model name as you like.
-You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
-you are logged in) or via the command line:
-```
-huggingface-cli repo create vit-base-patch16-imagenette
-```
-Next we clone the model repository to add the tokenizer and model files.
-```
-git clone https://huggingface.co/<your-username>/vit-base-patch16-imagenette
-```
-To ensure that all tensorboard traces will be uploaded correctly, we need to
-track them. You can run the following command inside your model repo to do so.
-```
-cd vit-base-patch16-imagenette
-git lfs track "*tfevents*"
-```
-Great, we have set up our model repository. During training, we will automatically
-push the training logs and model weights to the repo.
-Next, let's add a symbolic link to the `run_image_classification_flax.py`.
-```bash
-export MODEL_DIR="./vit-base-patch16-imagenette
-ln -s ~/transformers/examples/flax/summarization/run_image_classification_flax.py run_image_classification_flax.py
-```
 ## Prepare the dataset
 We will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).
@@ -86,7 +55,7 @@ Next we can run the example script to fine-tune the model:
 ```bash
 python run_image_classification.py \
-    --output_dir ${MODEL_DIR} \
+    --output_dir ./vit-base-patch16-imagenette \
     --model_name_or_path google/vit-base-patch16-224-in21k \
     --train_dir="imagenette2/train" \
     --validation_dir="imagenette2/val" \
...
@@ -42,6 +42,7 @@ from flax import jax_utils
 from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+from huggingface_hub import Repository
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
@@ -52,6 +53,7 @@ from transformers import (
     is_tensorboard_available,
     set_seed,
 )
+from transformers.file_utils import get_full_repo_name

 logger = logging.getLogger(__name__)
@@ -205,6 +207,16 @@ def main():
     # set seed for random transforms and torch dataloaders
     set_seed(training_args.seed)

+    # Handle the repository creation
+    if training_args.push_to_hub:
+        if training_args.hub_model_id is None:
+            repo_name = get_full_repo_name(
+                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+            )
+        else:
+            repo_name = training_args.hub_model_id
+        repo = Repository(training_args.output_dir, clone_from=repo_name)
+
     # Initialize datasets and pre-processing transforms
     # We use torchvision here for faster pre-processing
     # Note that here we are using some default pre-processing, for maximum accuray
@@ -455,12 +467,9 @@ def main():
        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-           model.save_pretrained(
-               training_args.output_dir,
-               params=params,
-               push_to_hub=training_args.push_to_hub,
-               commit_message=f"Saving weights and logs of epoch {epoch+1}",
-           )
+           model.save_pretrained(training_args.output_dir, params=params)
+           if training_args.push_to_hub:
+               repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

 if __name__ == "__main__":
...