Unverified commit 472a8676, authored by kumapo and committed by GitHub
Browse files

Add text_column_name and label_column_name to run_ner and run_ner_no_trainer args (#12083)

* Add text_column_name and label_column_name to run_ner args

* Minor fix: grouping for text and label column name
parent bc6f51e5
...@@ -106,6 +106,12 @@ class DataTrainingArguments: ...@@ -106,6 +106,12 @@ class DataTrainingArguments:
default=None, default=None,
metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
) )
text_column_name: Optional[str] = field(
default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
)
label_column_name: Optional[str] = field(
default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
)
overwrite_cache: bool = field( overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
) )
...@@ -249,10 +255,20 @@ def main(): ...@@ -249,10 +255,20 @@ def main():
else: else:
column_names = datasets["validation"].column_names column_names = datasets["validation"].column_names
features = datasets["validation"].features features = datasets["validation"].features
text_column_name = "tokens" if "tokens" in column_names else column_names[0]
label_column_name = ( if data_args.text_column_name is not None:
f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1] text_column_name = data_args.text_column_name
) elif "tokens" in column_names:
text_column_name = "tokens"
else:
text_column_name = column_names[0]
if data_args.label_column_name is not None:
label_column_name = data_args.label_column_name
elif f"{data_args.task_name}_tags" in column_names:
label_column_name = f"{data_args.task_name}_tags"
else:
label_column_name = column_names[1]
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
# unique labels. # unique labels.
......
...@@ -75,6 +75,12 @@ def parse_args(): ...@@ -75,6 +75,12 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
) )
parser.add_argument(
"--text_column_name", type=str, default=None, help="The column name of text to input in the file (a csv or JSON file)."
)
parser.add_argument(
"--label_column_name", type=str, default=None, help="The column name of label to input in the file (a csv or JSON file)."
)
parser.add_argument( parser.add_argument(
"--max_length", "--max_length",
type=int, type=int,
...@@ -259,8 +265,20 @@ def main(): ...@@ -259,8 +265,20 @@ def main():
else: else:
column_names = raw_datasets["validation"].column_names column_names = raw_datasets["validation"].column_names
features = raw_datasets["validation"].features features = raw_datasets["validation"].features
text_column_name = "tokens" if "tokens" in column_names else column_names[0]
label_column_name = f"{args.task_name}_tags" if f"{args.task_name}_tags" in column_names else column_names[1] if data_args.text_column_name is not None:
text_column_name = data_args.text_column_name
elif "tokens" in column_names:
text_column_name = "tokens"
else:
text_column_name = column_names[0]
if data_args.label_column_name is not None:
label_column_name = data_args.label_column_name
elif f"{data_args.task_name}_tags" in column_names:
label_column_name = f"{data_args.task_name}_tags"
else:
label_column_name = column_names[1]
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
# unique labels. # unique labels.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment