"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7419d807ff3d2ca45757c9e3090388b721e131ce"
Unverified Commit 074d6b75 authored by Quentin Lhoest's avatar Quentin Lhoest Committed by GitHub
Browse files

Simplify column_names in run_clm/mlm (#21382)

* simplify column_names in run_clm

* simplify column_names in run_mlm

* minor
parent c21298a6
...@@ -419,15 +419,9 @@ def main(): ...@@ -419,15 +419,9 @@ def main():
# Preprocessing the datasets. # Preprocessing the datasets.
# First we tokenize all the texts. # First we tokenize all the texts.
if training_args.do_train: if training_args.do_train:
if data_args.streaming: column_names = list(raw_datasets["train"].features)
column_names = raw_datasets["train"].features.keys()
else:
column_names = raw_datasets["train"].column_names
else: else:
if data_args.streaming: column_names = list(raw_datasets["validation"].features)
column_names = raw_datasets["validation"].features.keys()
else:
column_names = raw_datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0] text_column_name = "text" if "text" in column_names else column_names[0]
# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
......
...@@ -405,15 +405,9 @@ def main(): ...@@ -405,15 +405,9 @@ def main():
# Preprocessing the datasets. # Preprocessing the datasets.
# First we tokenize all the texts. # First we tokenize all the texts.
if training_args.do_train: if training_args.do_train:
if data_args.streaming: column_names = list(raw_datasets["train"].features)
column_names = raw_datasets["train"].features.keys()
else:
column_names = raw_datasets["train"].column_names
else: else:
if data_args.streaming: column_names = list(raw_datasets["validation"].features)
column_names = raw_datasets["validation"].features.keys()
else:
column_names = raw_datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0] text_column_name = "text" if "text" in column_names else column_names[0]
if data_args.max_seq_length is None: if data_args.max_seq_length is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment