Unverified Commit d5b8fe3b authored by Souvic Chakraborty, committed by GitHub

Validation split added: custom data files @sgugger, @patil-suraj (#12407)



* Validation split added: custom data files

A validation split is now created when custom data files are loaded without a validation file.

* Updated documentation with custom file usage

Updated documentation with custom file usage

* Update README.md

* Update README.md

* Update README.md

* Made some suggested stylistic changes

* Used logger instead of print.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Made similar changes to add validation split

In case of a missing validation file, a validation split will be used now.

* max_train_samples to be used for training only

max_train_samples was misplaced; it is now applied to the training data only, not the whole dataset.

* styled

* changed ordering

* Improved language of documentation
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Improved language of documentation
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fixed styling issue

* Update run_mlm.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent f929462b
@@ -49,6 +49,14 @@ python run_mlm.py \
    --dataset_config_name wikitext-103-raw-v1
```
When using a custom dataset, the validation file can be passed separately as an input argument; otherwise, a customizable fraction of the training data is held out and used as the validation set.
```
python run_mlm.py \
    --model_name_or_path distilbert-base-cased \
    --output_dir output \
    --train_file train_file_path
```
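If a separate validation file is available, it can be supplied explicitly. The `--validation_file` flag below is assumed from the `data_args.validation_file` attribute used in the script, and `validation_file_path` is only a placeholder:
```
python run_mlm.py \
    --model_name_or_path distilbert-base-cased \
    --output_dir output \
    --train_file train_file_path \
    --validation_file validation_file_path
```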
## run_clm.py
This script trains a causal language model.
@@ -61,3 +69,12 @@ python run_clm.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-103-raw-v1
```
When using a custom dataset, the validation file can be passed separately as an input argument; otherwise, a customizable fraction of the training data is held out and used as the validation set.
```
python run_clm.py \
    --model_name_or_path distilgpt2 \
    --output_dir output \
    --train_file train_file_path
```
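The size of the held-out split can be adjusted as well. The `--validation_split_percentage` flag below is inferred from the `data_args.validation_split_percentage` attribute referenced in the code change; the value 10 is only an example:
```
python run_clm.py \
    --model_name_or_path distilgpt2 \
    --output_dir output \
    --train_file train_file_path \
    --validation_split_percentage 10
```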
@@ -37,6 +37,7 @@ import datasets
import numpy as np
import tensorflow as tf
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import transformers
from transformers import (
@@ -429,7 +430,18 @@ def main():
)
train_dataset = lm_datasets["train"]
if data_args.validation_file is not None:
    eval_dataset = lm_datasets["validation"]
else:
    logger.info(
        f"Validation file not found: using {data_args.validation_split_percentage}% of the training data as the validation set."
    )
    # validation_split_percentage is an integer percentage (e.g. 5), so convert it to a fraction;
    # sklearn's train_test_split would treat an integer test_size as an absolute sample count.
    train_indices, val_indices = train_test_split(
        list(range(len(train_dataset))), test_size=data_args.validation_split_percentage / 100
    )
    eval_dataset = train_dataset.select(val_indices)
    train_dataset = train_dataset.select(train_indices)
if data_args.max_train_samples is not None:
    train_dataset = train_dataset.select(range(data_args.max_train_samples))
...
@@ -39,6 +39,7 @@ import datasets
import numpy as np
import tensorflow as tf
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import transformers
from transformers import (
@@ -363,6 +364,7 @@ def main():
if extension == "txt":
    extension = "text"
raw_datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# endregion
@@ -488,9 +490,22 @@ def main():
)
train_dataset = tokenized_datasets["train"]
if data_args.validation_file is not None:
    eval_dataset = tokenized_datasets["validation"]
else:
    logger.info(
        f"Validation file not found: using {data_args.validation_split_percentage}% of the training data as the validation set."
    )
    # validation_split_percentage is an integer percentage (e.g. 5), so convert it to a fraction;
    # sklearn's train_test_split would treat an integer test_size as an absolute sample count.
    train_indices, val_indices = train_test_split(
        list(range(len(train_dataset))), test_size=data_args.validation_split_percentage / 100
    )
    eval_dataset = train_dataset.select(val_indices)
    train_dataset = train_dataset.select(train_indices)
if data_args.max_train_samples is not None:
    train_dataset = train_dataset.select(range(data_args.max_train_samples))
if data_args.max_eval_samples is not None:
    eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
...