Unverified Commit 390e121f authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Examples] TPU-based training of a language model using TensorFlow (#21657)



* add: tokenizer training script for TF TPU LM training.

* add: script for preparing the TFRecord shards.

* add: sequence of execution to readme.

* remove limit from the tfrecord shard name.

* Add initial train_model.py

* Add basic training arguments and model init

* Get up to the point of writing the data collator

* Pushing progress so far!

* Complete first draft of model training code

* feat: grouping of texts efficiently.
Co-authored-by: default avatarMatt <rocketknight1@gmail.com>

* Add proper masking collator and get training loop working

* fix: things.

* Read sample counts from filenames

* Read sample counts from filenames

* Draft README

* Improve TPU warning

* Use distribute instead of distribute.experimental

* Apply suggestions from code review
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>

* Modularize loading and add MLM probability as arg

* minor refactoring to better use the cli args.

* readme fillup.

* include tpu and inference sections in the readme.

* table of contents.

* parallelize maps.

* polish readme.

* change script name to run_mlm.py

* address PR feedback (round I).

---------
Co-authored-by: default avatarMatt <rocketknight1@gmail.com>
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>
parent bfb3925f
# Training a masked language model end-to-end from scratch on TPUs
In this example, we're going to demonstrate how to train a TensorFlow model from 🤗 Transformers from scratch. If you're interested in some background theory on training Hugging Face models with TensorFlow on TPU, please check out our
[tutorial doc](https://huggingface.co/docs/transformers/main/perf_train_tpu_tf) on this topic!
If you're interested in smaller-scale TPU training from a pre-trained checkpoint, you can also check out the [TPU fine-tuning example](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/tpu_training-tf.ipynb).
This example will demonstrate pre-training language models at the 100M-1B parameter scale, similar to BERT or GPT-2. More concretely, we will show how to train a [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta) (base model) from scratch on the [WikiText dataset (v1)](https://huggingface.co/datasets/wikitext).
We've tried to ensure that all the practices we show you here are scalable, though - with relatively few changes, the code could be scaled up to much larger models.
Google's gargantuan [PaLM model](https://arxiv.org/abs/2204.02311), with
over 500B parameters, is a good example of how far you can go with pure TPU training, though gathering the dataset and the budget to train at that scale is not an easy task!
### Table of contents
- [Setting up a TPU-VM](#setting-up-a-tpu-vm)
- [Training a tokenizer](#training-a-tokenizer)
- [Preparing the dataset](#preparing-the-dataset)
- [Training the model](#training-the-model)
- [Inference](#inference)
## Setting up a TPU-VM
Since this example focuses on using TPUs, the first step is to set up access to TPU hardware. For this example, we chose to use a TPU v3-8 VM. Follow [this guide](https://cloud.google.com/tpu/docs/run-calculation-tensorflow) to quickly create a TPU VM with TensorFlow pre-installed.
> 💡 **Note**: You don't need a TPU-enabled hardware for tokenizer training and TFRecord shard preparation.
## Training a tokenizer
To train a language model from scratch, the first step is to tokenize text. In most Hugging Face examples, we begin from a pre-trained model and use its tokenizer. However, in this example, we're going to train a tokenizer from scratch as well. The script for this is `train_unigram.py`. An example command is:
```bash
python train_unigram.py --batch_size 1000 --vocab_size 25000 --export_to_hub
```
The script will automatically load the `train` split of the WikiText dataset and train a [Unigram tokenizer](https://huggingface.co/course/chapter6/7?fw=pt) on it.
> 💡 **Note**: In order for `export_to_hub` to work, you must authenticate yourself with the `huggingface-cli`. Run `huggingface-cli login` and follow the on-screen instructions.
## Preparing the dataset
The next step is to prepare the dataset. This consists of loading a text dataset from the Hugging Face Hub, tokenizing it and grouping it into chunks of a fixed length ready for training. The script for this is `prepare_tfrecord_shards.py`.
The reason we create TFRecord output files from this step is that these files work well with [`tf.data` pipelines](https://www.tensorflow.org/guide/data_performance). This makes them very suitable for scalable TPU training - the dataset can easily be sharded and read in parallel just by tweaking a few parameters in the pipeline. An example command is:
```bash
python prepare_tfrecord_shards.py \
--tokenizer_name_or_path tf-tpu/unigram-tokenizer-wikitext \
--shard_size 5000 \
--split test
--max_length 128 \
--output_dir gs://tf-tpu-training-resources
```
**Notes**:
* While running the above script, you need to specify the `split` accordingly. The example command above will only filter the `test` split of the dataset.
* If you append `gs://` in your `output_dir` the TFRecord shards will be directly serialized to a Google Cloud Storage (GCS) bucket. Ensure that you have already [created the GCS bucket](https://cloud.google.com/storage/docs).
* If you're using a TPU node, you must stream data from a GCS bucket. Otherwise, if you're using a TPU VM,you can store the data locally. You may need to [attach](https://cloud.google.com/tpu/docs/setup-persistent-disk) a persistent storage to the VM.
* Additional CLI arguments are also supported. We encourage you to run `python prepare_tfrecord_shards.py -h` to know more about them.
## Training the model
Once that's done, the model is ready for training. By default, training takes place on TPU, but you can use the `--no_tpu` flag to train on CPU for testing purposes. An example command is:
```bash
python3 run_mlm.py \
--train_dataset gs://tf-tpu-training-resources/train/ \
--eval_dataset gs://tf-tpu-training-resources/validation/ \
--tokenizer tf-tpu/unigram-tokenizer-wikitext \
--output_dir trained_model
```
If you had specified a `hub_model_id` while launching training, then your model will be pushed to a model repository on the Hugging Face Hub. You can find such an example repository here:
[tf-tpu/roberta-base-epochs-500-no-wd](https://huggingface.co/tf-tpu/roberta-base-epochs-500-no-wd).
## Inference
Once the model is trained, you can use 🤗 Pipelines to perform inference:
```python
from transformers import pipeline
model_id = "tf-tpu/roberta-base-epochs-500-no-wd"
unmasker = pipeline("fill-mask", model=model_id, framework="tf")
unmasker("Goal of my life is to [MASK].")
[{'score': 0.1003185287117958,
'token': 52,
'token_str': 'be',
'sequence': 'Goal of my life is to be.'},
{'score': 0.032648514956235886,
'token': 5,
'token_str': '',
'sequence': 'Goal of my life is to .'},
{'score': 0.02152673341333866,
'token': 138,
'token_str': 'work',
'sequence': 'Goal of my life is to work.'},
{'score': 0.019547373056411743,
'token': 984,
'token_str': 'act',
'sequence': 'Goal of my life is to act.'},
{'score': 0.01939118467271328,
'token': 73,
'token_str': 'have',
'sequence': 'Goal of my life is to have.'}]
```
You can also try out inference using the [Inference Widget](https://huggingface.co/tf-tpu/roberta-base-epochs-500-no-wd?text=Goal+of+my+life+is+to+%5BMASK%5D.) from the model page.
\ No newline at end of file
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for preparing TFRecord shards for pre-tokenized examples."""
import argparse
import logging
import os
import datasets
import tensorflow as tf
from transformers import AutoTokenizer
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(
description="Prepare TFRecord shards from pre-tokenized samples of the wikitext dataset."
)
parser.add_argument(
"--tokenizer_name_or_path",
type=str,
default="sayakpaul/unigram-tokenizer-wikitext",
help="Tokenizer identifier. Can be a local filepath or a Hub identifier.",
)
parser.add_argument(
"--shard_size",
type=int,
default=1000,
help="Number of entries to go in a single shard.",
)
parser.add_argument("--split", type=str, default="train", choices=["train", "test", "validation"])
parser.add_argument(
"--limit",
default=None,
type=int,
help="Limit the number of shards (used for debugging).",
)
parser.add_argument(
"--max_length",
type=int,
default=512,
help="Maximum sequence length. For training on TPUs, it helps to have a maximum"
" sequence length that is a multiple of 8.",
)
parser.add_argument(
"--output_dir",
default="tf-tpu",
type=str,
help="Output directory where the TFRecord shards will be saved. If the"
" path is appended with `gs://` ('gs://tf-tpu', for example) then the TFRecord"
" shards will be directly saved to a Google Cloud Storage bucket.",
)
args = parser.parse_args()
return args
def tokenize_function(tokenizer):
def fn(examples):
return tokenizer(examples["text"])
return fn
def get_serialized_examples(tokenized_data):
records = []
for i in range(len(tokenized_data["input_ids"])):
features = {
"input_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=tokenized_data["input_ids"][i])),
"attention_mask": tf.train.Feature(
int64_list=tf.train.Int64List(value=tokenized_data["attention_mask"][i])
),
}
features = tf.train.Features(feature=features)
example = tf.train.Example(features=features)
record_bytes = example.SerializeToString()
records.append(record_bytes)
return records
def main(args):
wikitext = datasets.load_dataset("wikitext", "wikitext-103-raw-v1", split=args.split)
if args.limit is not None:
max_samples = min(len(wikitext), args.limit)
wikitext = wikitext.select(range(max_samples))
print(f"Limiting the dataset to {args.limit} entries.")
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
# Handle output directory creation.
# For serializing into a Google Cloud Storage Bucket, one needs to first
# create a bucket.
if "gs" not in args.output_dir:
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
split_dir = os.path.join(args.output_dir, args.split)
if not os.path.exists(split_dir):
os.makedirs(split_dir)
else:
split_dir = os.path.join(args.output_dir, args.split)
# Tokenize the whole dataset at once.
tokenize_fn = tokenize_function(tokenizer)
wikitext_tokenized = wikitext.map(tokenize_fn, batched=True, num_proc=4, remove_columns=["text"])
# We need to concatenate all our texts together, and then split the result
# into chunks of a fixed size, which we will call block_size. To do this, we
# will use the map method again, with the option batched=True. When we use batched=True,
# the function we pass to map() will be passed multiple inputs at once, allowing us
# to group them into more or fewer examples than we had in the input.
# This allows us to create our new fixed-length samples. The advantage of this
# method is that we don't lose a whole lot of content from the dataset compared to the
# case where we simply tokenize with a pre-defined max_length.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, though you could add padding instead if the model supports it
# In this, as in all things, we advise you to follow your heart 🫀
total_length = (total_length // args.max_length) * args.max_length
# Split by chunks of max_len.
result = {
k: [t[i : i + args.max_length] for i in range(0, total_length, args.max_length)]
for k, t in concatenated_examples.items()
}
return result
grouped_dataset = wikitext_tokenized.map(group_texts, batched=True, batch_size=1000, num_proc=4)
shard_count = 0
total_records = 0
for shard in range(0, len(grouped_dataset), args.shard_size):
dataset_snapshot = grouped_dataset[shard : shard + args.shard_size]
records_containing = len(dataset_snapshot["input_ids"])
filename = os.path.join(split_dir, f"wikitext-{shard_count}-{records_containing}.tfrecord")
serialized_examples = get_serialized_examples(dataset_snapshot)
with tf.io.TFRecordWriter(filename) as out_file:
for i in range(len(serialized_examples)):
example = serialized_examples[i]
out_file.write(example)
print("Wrote file {} containing {} records".format(filename, records_containing))
shard_count += 1
total_records += records_containing
with open(f"split-{args.split}-records-count.txt", "w") as f:
print(f"Total {args.split} records: {total_records}", file=f)
if __name__ == "__main__":
args = parse_args()
main(args)
transformers==4.26.1
datasets==2.9.0
tokenizers==0.13.2
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training a masked language model on TPU."""
import argparse
import logging
import os
import re
import tensorflow as tf
from transformers import (
AutoConfig,
AutoTokenizer,
DataCollatorForLanguageModeling,
PushToHubCallback,
TFAutoModelForMaskedLM,
create_optimizer,
)
logger = logging.getLogger(__name__)
AUTO = tf.data.AUTOTUNE
def parse_args():
parser = argparse.ArgumentParser(description="Train a masked language model on TPU.")
parser.add_argument(
"--pretrained_model_config",
type=str,
default="roberta-base",
help="The model config to use. Note that we don't copy the model's weights, only the config!",
)
parser.add_argument(
"--tokenizer",
type=str,
default="unigram-tokenizer-wikitext",
help="The name of the tokenizer to load. We use the pretrained tokenizer to initialize the model's vocab size.",
)
parser.add_argument(
"--per_replica_batch_size",
type=int,
default=8,
help="Batch size per TPU core.",
)
parser.add_argument(
"--no_tpu",
action="store_true",
help="If set, run on CPU and don't try to initialize a TPU. Useful for debugging on non-TPU instances.",
)
parser.add_argument(
"--tpu_name",
type=str,
help="Name of TPU resource to initialize. Should be blank on Colab, and 'local' on TPU VMs.",
default="local",
)
parser.add_argument(
"--tpu_zone",
type=str,
help="Google cloud zone that TPU resource is located in. Only used for non-Colab TPU nodes.",
)
parser.add_argument(
"--gcp_project", type=str, help="Google cloud project name. Only used for non-Colab TPU nodes."
)
parser.add_argument(
"--bfloat16",
action="store_true",
help="Use mixed-precision bfloat16 for training. This is the recommended lower-precision format for TPU.",
)
parser.add_argument(
"--train_dataset",
type=str,
help="Path to training dataset to load. If the path begins with `gs://`"
" then the dataset will be loaded from a Google Cloud Storage bucket.",
)
parser.add_argument(
"--shuffle_buffer_size",
type=int,
default=2**18, # Default corresponds to a 1GB buffer for seq_len 512
help="Size of the shuffle buffer (in samples)",
)
parser.add_argument(
"--eval_dataset",
type=str,
help="Path to evaluation dataset to load. If the path begins with `gs://`"
" then the dataset will be loaded from a Google Cloud Storage bucket.",
)
parser.add_argument(
"--num_epochs",
type=int,
default=1,
help="Number of epochs to train for.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Learning rate to use for training.",
)
parser.add_argument(
"--weight_decay_rate",
type=float,
default=1e-3,
help="Weight decay rate to use for training.",
)
parser.add_argument(
"--max_length",
type=int,
default=512,
help="Maximum length of tokenized sequences. Should match the setting used in prepare_tfrecord_shards.py",
)
parser.add_argument(
"--mlm_probability",
type=float,
default=0.15,
help="Fraction of tokens to mask during training.",
)
parser.add_argument("--output_dir", type=str, required=True, help="Path to save model checkpoints to.")
parser.add_argument("--hub_model_id", type=str, help="Model ID to upload to on the Hugging Face Hub.")
args = parser.parse_args()
return args
def initialize_tpu(args):
try:
if args.tpu_name:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
args.tpu_name, zone=args.tpu_zone, project=args.gcp_project
)
else:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
raise RuntimeError(
"Couldn't connect to TPU! Most likely you need to specify --tpu_name, --tpu_zone, or "
"--gcp_project. When running on a TPU VM, use --tpu_name local."
)
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
return tpu
def count_samples(file_list):
num_samples = 0
for file in file_list:
filename = file.split("/")[-1]
sample_count = re.search(r"-\d+-(\d+)\.tfrecord", filename).group(1)
sample_count = int(sample_count)
num_samples += sample_count
return num_samples
def prepare_dataset(records, decode_fn, mask_fn, batch_size, shuffle, shuffle_buffer_size=None):
num_samples = count_samples(records)
dataset = tf.data.Dataset.from_tensor_slices(records)
if shuffle:
dataset = dataset.shuffle(len(dataset))
dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=AUTO)
# TF can't infer the total sample count because it doesn't read all the records yet, so we assert it here
dataset = dataset.apply(tf.data.experimental.assert_cardinality(num_samples))
dataset = dataset.map(decode_fn, num_parallel_calls=AUTO)
if shuffle:
assert shuffle_buffer_size is not None
dataset = dataset.shuffle(args.shuffle_buffer_size)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.map(mask_fn, num_parallel_calls=AUTO)
dataset = dataset.prefetch(AUTO)
return dataset
def main(args):
if not args.no_tpu:
tpu = initialize_tpu(args)
strategy = tf.distribute.TPUStrategy(tpu)
else:
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
if args.bfloat16:
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
config = AutoConfig.from_pretrained(args.pretrained_model_config)
config.vocab_size = tokenizer.vocab_size
training_records = tf.io.gfile.glob(os.path.join(args.train_dataset, "*.tfrecord"))
if not training_records:
raise ValueError(f"No .tfrecord files found in {args.train_dataset}.")
eval_records = tf.io.gfile.glob(os.path.join(args.eval_dataset, "*.tfrecord"))
if not eval_records:
raise ValueError(f"No .tfrecord files found in {args.eval_dataset}.")
num_train_samples = count_samples(training_records)
steps_per_epoch = num_train_samples // (args.per_replica_batch_size * strategy.num_replicas_in_sync)
total_train_steps = steps_per_epoch * args.num_epochs
with strategy.scope():
model = TFAutoModelForMaskedLM.from_config(config)
model(model.dummy_inputs) # Pass some dummy inputs through the model to ensure all the weights are built
optimizer, schedule = create_optimizer(
num_train_steps=total_train_steps,
num_warmup_steps=total_train_steps // 20,
init_lr=args.learning_rate,
weight_decay_rate=args.weight_decay_rate,
# TODO Add the other Adam parameters?
)
model.compile(optimizer=optimizer, metrics=["accuracy"])
def decode_fn(example):
features = {
"input_ids": tf.io.FixedLenFeature(dtype=tf.int64, shape=(args.max_length,)),
"attention_mask": tf.io.FixedLenFeature(dtype=tf.int64, shape=(args.max_length,)),
}
return tf.io.parse_single_example(example, features)
# Many of the data collators in Transformers are TF-compilable when return_tensors == "tf", so we can
# use their methods in our data pipeline.
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mlm_probability=args.mlm_probability, mlm=True, return_tensors="tf"
)
def mask_with_collator(batch):
# TF really needs an isin() function
special_tokens_mask = (
~tf.cast(batch["attention_mask"], tf.bool)
| (batch["input_ids"] == tokenizer.cls_token_id)
| (batch["input_ids"] == tokenizer.sep_token_id)
)
batch["input_ids"], batch["labels"] = data_collator.tf_mask_tokens(
batch["input_ids"],
vocab_size=len(tokenizer),
mask_token_id=tokenizer.mask_token_id,
special_tokens_mask=special_tokens_mask,
)
return batch
batch_size = args.per_replica_batch_size * strategy.num_replicas_in_sync
train_dataset = prepare_dataset(
training_records,
decode_fn=decode_fn,
mask_fn=mask_with_collator,
batch_size=batch_size,
shuffle=True,
shuffle_buffer_size=args.shuffle_buffer_size,
)
eval_dataset = prepare_dataset(
eval_records,
decode_fn=decode_fn,
mask_fn=mask_with_collator,
batch_size=batch_size,
shuffle=False,
)
callbacks = []
if args.hub_model_id:
callbacks.append(
PushToHubCallback(output_dir=args.output_dir, hub_model_id=args.hub_model_id, tokenizer=tokenizer)
)
model.fit(
train_dataset,
validation_data=eval_dataset,
epochs=args.num_epochs,
callbacks=callbacks,
)
model.save_pretrained(args.output_dir)
if __name__ == "__main__":
args = parse_args()
main(args)
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training a Unigram tokenizer."""
import argparse
import logging
import datasets
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
from transformers import AlbertTokenizerFast
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Train a unigram tokenizer on the wikitext dataset.")
parser.add_argument(
"--dataset_name",
type=str,
default="wikitext",
help="Name of the training. Explore datasets at: hf.co/datasets.",
)
parser.add_argument(
"--dataset_config", type=str, default="wikitext-103-raw-v1", help="Configuration name of the dataset."
)
parser.add_argument(
"--batch_size",
type=int,
default=1000,
help="Batch size during training.",
)
parser.add_argument(
"--vocab_size",
type=int,
default=10048,
help="Size of the desired vocabulary.",
)
parser.add_argument(
"--limit",
default=None,
type=int,
help="Limit the number of shards (used for debugging).",
)
parser.add_argument(
"--export_to_hub",
action="store_true",
)
args = parser.parse_args()
return args
def main(args):
wikitext = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
if args.limit is not None:
max_train_samples = min(len(wikitext), args.limit)
wikitext = wikitext.select(range(max_train_samples))
logger.info(f"Limiting the dataset to {args.limit} entries.")
def batch_iterator():
for i in range(0, len(wikitext), args.batch_size):
yield wikitext[i : i + args.batch_size]["text"]
# Prepare the tokenizer.
tokenizer = Tokenizer(Unigram())
tokenizer.normalizer = normalizers.Sequence([normalizers.Replace("``", '"'), normalizers.Replace("''", '"')])
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()
# Prepare the trainer.
trainer = UnigramTrainer(
unk_token="<unk>",
special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"],
vocab_size=args.vocab_size,
)
logger.info("Training the tokenizer.")
tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
logger.info("Tokenizer training complete!")
cls_token_id = tokenizer.token_to_id("[CLS]")
sep_token_id = tokenizer.token_to_id("[SEP]")
tokenizer.post_processor = processors.TemplateProcessing(
single="[CLS]:0 $A:0 [SEP]:0",
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
special_tokens=[
("[CLS]", cls_token_id),
("[SEP]", sep_token_id),
],
)
tokenizer.decoder = decoders.Metaspace()
if args.export_to_hub:
logger.info("Exporting the trained tokenzier to Hub.")
new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
new_tokenizer.push_to_hub("unigram-tokenizer-wikitext")
if __name__ == "__main__":
args = parse_args()
main(args)
......@@ -664,6 +664,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
"""
import tensorflow as tf
mask_token_id = tf.cast(mask_token_id, inputs.dtype)
input_shape = tf.shape(inputs)
# 1 for a special token, 0 for a normal token in the special tokens mask
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
......@@ -677,8 +679,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
inputs = tf.where(indices_replaced, mask_token_id, inputs)
# 10% of the time, we replace masked input tokens with random word
indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=tf.int64)
indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
inputs = tf.where(indices_random, random_words, inputs)
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment