Unverified Commit d2753dcb authored by Bhavitvya Malik, committed by GitHub

add relevant description to tqdm in examples (#11927)

* add relevant `desc` in examples

* require_version datasets>=1.8.0
parent 9a9314f6
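
For context, a minimal sketch (not part of this commit) of what the new `desc` argument does: starting with `datasets>=1.8.0`, `Dataset.map` can label its tqdm progress bar, which is why the requirement is bumped in the diff below. The `preprocess_function` here is a toy stand-in for the examples' real tokenizer call:

```python
from datasets import Dataset

ds = Dataset.from_dict({"sentence": ["a good movie", "a bad movie"]})

def preprocess_function(examples):
    # Toy stand-in for the tokenizer call used in the examples.
    return {"n_words": [len(s.split()) for s in examples["sentence"]]}

# With datasets>=1.8.0, `desc` prefixes the progress bar, e.g.:
#   Running tokenizer on dataset: 100%|##########| 1/1 [00:00<00:00, ...]
ds = ds.map(preprocess_function, batched=True, desc="Running tokenizer on dataset")
```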
 accelerate
-datasets >= 1.1.3
+datasets >= 1.8.0
 sentencepiece != 0.1.92
 protobuf
 torch >= 1.3
@@ -42,10 +42,12 @@ from transformers import (
 )
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
+from transformers.utils.versions import require_version

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.7.0.dev0")

+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

 task_to_keys = {
     "cola": ("sentence", None),
@@ -393,7 +395,12 @@ def main():
         result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
         return result

-    datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
+    datasets = datasets.map(
+        preprocess_function,
+        batched=True,
+        load_from_cache_file=not data_args.overwrite_cache,
+        desc="Running tokenizer on dataset",
+    )

     if training_args.do_train:
         if "train" not in datasets:
             raise ValueError("--do_train requires a train dataset")
...
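
As a side note, a minimal sketch of the `require_version` guard added above, called exactly as in the diff: it is a no-op when the installed `datasets` satisfies the spec, and otherwise raises an `ImportError` that includes the found version plus the fix hint.

```python
from transformers.utils.versions import require_version

# Passes silently if the installed `datasets` meets the requirement; otherwise
# raises ImportError with the hint string appended to the message.
require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/text-classification/requirements.txt",
)
```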
@@ -38,10 +38,13 @@ from transformers import (
     get_scheduler,
     set_seed,
 )
+from transformers.utils.versions import require_version

 logger = logging.getLogger(__name__)

+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

 task_to_keys = {
     "cola": ("sentence", None),
     "mnli": ("premise", "hypothesis"),
@@ -305,7 +308,10 @@ def main():
         return result

     processed_datasets = raw_datasets.map(
-        preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
+        preprocess_function,
+        batched=True,
+        remove_columns=raw_datasets["train"].column_names,
+        desc="Running tokenizer on dataset",
     )

     train_dataset = processed_datasets["train"]
...
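
A minimal sketch of the no_trainer pattern above, using toy data and a toy preprocessing function: mapping a `DatasetDict` with `remove_columns` drops the raw text columns, while `desc` labels the progress bars.

```python
from datasets import Dataset, DatasetDict

raw_datasets = DatasetDict(
    {
        "train": Dataset.from_dict({"sentence": ["good movie", "bad movie"]}),
        "validation": Dataset.from_dict({"sentence": ["fine movie"]}),
    }
)

def preprocess_function(examples):
    # Toy stand-in for the tokenizer; emits one new column.
    return {"input_ids": [[len(w) for w in s.split()] for s in examples["sentence"]]}

processed_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
    desc="Running tokenizer on dataset",
)
print(processed_datasets["train"].column_names)  # ['input_ids']: 'sentence' was removed
```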
@@ -42,10 +42,12 @@ from transformers import (
 )
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
+from transformers.utils.versions import require_version

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.7.0.dev0")

+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

 logger = logging.getLogger(__name__)
@@ -280,6 +282,7 @@ def main():
             preprocess_function,
             batched=True,
             load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on train dataset",
         )

     # Log a few random samples from the training set:
     for index in random.sample(range(len(train_dataset)), 3):
@@ -292,6 +295,7 @@ def main():
             preprocess_function,
             batched=True,
             load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on validation dataset",
         )

     if training_args.do_predict:
@@ -301,6 +305,7 @@ def main():
             preprocess_function,
             batched=True,
             load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on prediction dataset",
         )

     # Get the metric function
...
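
Note that each split's `map()` call above gets its own description ("train" / "validation" / "prediction"), so the three otherwise identical tqdm bars can be told apart in the logs.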