ModelZoo / ResNet50_tensorflow · Commit e9057c4d

Authored Nov 14, 2020 by Tianqi Liu; committed by A. Unique TensorFlower, Nov 14, 2020

Add support for using translated data in XTREME benchmarks.

PiperOrigin-RevId: 342448071
Parent: f409e4d0
Showing 6 changed files with 369 additions and 99 deletions (+369 −99):
official/nlp/data/classifier_data_lib.py     +161 −36
official/nlp/data/create_finetuning_data.py  +72  −41
official/nlp/data/squad_lib.py               +13  −2
official/nlp/data/squad_lib_sp.py            +14  −2
official/nlp/data/tagging_data_lib.py        +78  −6
official/nlp/tasks/sentence_prediction.py    +31  −12
official/nlp/data/classifier_data_lib.py

@@ -938,45 +938,104 @@ class XtremePawsxProcessor(DataProcessor):
   """Processor for the XTREME PAWS-X data set."""
   supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
 
+  def __init__(self,
+               process_text_fn=tokenization.convert_to_unicode,
+               translated_data_dir=None,
+               only_use_en_dev=True):
+    """See base class.
+
+    Arguments:
+      process_text_fn: See base class.
+      translated_data_dir: If specified, will also include translated data in
+        the training and testing data.
+      only_use_en_dev: If True, only use english dev data. Otherwise, use dev
+        data from all languages.
+    """
+    super(XtremePawsxProcessor, self).__init__(process_text_fn)
+    self.translated_data_dir = translated_data_dir
+    self.only_use_en_dev = only_use_en_dev
+
   def get_train_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
     examples = []
-    for i, line in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    if self.translated_data_dir is None:
+      lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+      for i, line in enumerate(lines):
+        guid = "train-%d" % i
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = self.process_text_fn(line[2])
+        examples.append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    else:
+      for lang in self.supported_languages:
+        lines = self._read_tsv(
+            os.path.join(self.translated_data_dir, "translate-train",
+                         f"en-{lang}-translated.tsv"))
+        for i, line in enumerate(lines):
+          guid = f"train-{lang}-{i}"
+          text_a = self.process_text_fn(line[2])
+          text_b = self.process_text_fn(line[3])
+          label = self.process_text_fn(line[4])
+          examples.append(
+              InputExample(
+                  guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
 
   def get_dev_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
     examples = []
-    for i, line in enumerate(lines):
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    if self.only_use_en_dev:
+      lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+      for i, line in enumerate(lines):
+        guid = "dev-%d" % i
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = self.process_text_fn(line[2])
+        examples.append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    else:
+      for lang in self.supported_languages:
+        lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
+        for i, line in enumerate(lines):
+          guid = f"dev-{lang}-{i}"
+          text_a = self.process_text_fn(line[0])
+          text_b = self.process_text_fn(line[1])
+          label = self.process_text_fn(line[2])
+          examples.append(
+              InputExample(
+                  guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
 
   def get_test_examples(self, data_dir):
     """See base class."""
-    examples_by_lang = {k: [] for k in self.supported_languages}
+    examples_by_lang = {}
     for lang in self.supported_languages:
+      examples_by_lang[lang] = []
       lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
       for i, line in enumerate(lines):
-        guid = "test-%d" % i
+        guid = f"test-{lang}-{i}"
         text_a = self.process_text_fn(line[0])
         text_b = self.process_text_fn(line[1])
         label = "0"
         examples_by_lang[lang].append(
             InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    if self.translated_data_dir is not None:
+      for lang in self.supported_languages:
+        if lang == "en":
+          continue
+        examples_by_lang[f"{lang}-en"] = []
+        lines = self._read_tsv(
+            os.path.join(self.translated_data_dir, "translate-test",
+                         f"test-{lang}-en-translated.tsv"))
+        for i, line in enumerate(lines):
+          guid = f"test-{lang}-en-{i}"
+          text_a = self.process_text_fn(line[2])
+          text_b = self.process_text_fn(line[3])
+          label = "0"
+          examples_by_lang[f"{lang}-en"].append(
+              InputExample(
+                  guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples_by_lang
 
   def get_labels(self):
@@ -996,45 +1055,111 @@ class XtremeXnliProcessor(DataProcessor):
       "ur", "vi", "zh"
   ]
 
+  def __init__(self,
+               process_text_fn=tokenization.convert_to_unicode,
+               translated_data_dir=None,
+               only_use_en_dev=True):
+    """See base class.
+
+    Arguments:
+      process_text_fn: See base class.
+      translated_data_dir: If specified, will also include translated data in
+        the training data.
+      only_use_en_dev: If True, only use english dev data. Otherwise, use dev
+        data from all languages.
+    """
+    super(XtremeXnliProcessor, self).__init__(process_text_fn)
+    self.translated_data_dir = translated_data_dir
+    self.only_use_en_dev = only_use_en_dev
+
   def get_train_examples(self, data_dir):
     """See base class."""
     lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
 
     examples = []
-    for i, line in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    if self.translated_data_dir is None:
+      for i, line in enumerate(lines):
+        guid = "train-%d" % i
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = self.process_text_fn(line[2])
+        if label == self.process_text_fn("contradictory"):
+          label = self.process_text_fn("contradiction")
+        examples.append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    else:
+      for lang in self.supported_languages:
+        lines = self._read_tsv(
+            os.path.join(self.translated_data_dir, "translate-train",
+                         f"en-{lang}-translated.tsv"))
+        for i, line in enumerate(lines):
+          guid = f"train-{lang}-{i}"
+          text_a = self.process_text_fn(line[2])
+          text_b = self.process_text_fn(line[3])
+          label = self.process_text_fn(line[4])
+          if label == self.process_text_fn("contradictory"):
+            label = self.process_text_fn("contradiction")
+          examples.append(
+              InputExample(
+                  guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
 
   def get_dev_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
     examples = []
-    for i, line in enumerate(lines):
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    if self.only_use_en_dev:
+      lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+      for i, line in enumerate(lines):
+        guid = "dev-%d" % i
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = self.process_text_fn(line[2])
+        examples.append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    else:
+      for lang in self.supported_languages:
+        lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
+        for i, line in enumerate(lines):
+          guid = f"dev-{lang}-{i}"
+          text_a = self.process_text_fn(line[0])
+          text_b = self.process_text_fn(line[1])
+          label = self.process_text_fn(line[2])
+          if label == self.process_text_fn("contradictory"):
+            label = self.process_text_fn("contradiction")
+          examples.append(
+              InputExample(
+                  guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
 
   def get_test_examples(self, data_dir):
     """See base class."""
-    examples_by_lang = {k: [] for k in self.supported_languages}
+    examples_by_lang = {}
     for lang in self.supported_languages:
+      examples_by_lang[lang] = []
       lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
       for i, line in enumerate(lines):
-        guid = f"test-{i}"
+        guid = f"test-{lang}-{i}"
         text_a = self.process_text_fn(line[0])
         text_b = self.process_text_fn(line[1])
         label = "contradiction"
         examples_by_lang[lang].append(
             InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    if self.translated_data_dir is not None:
+      for lang in self.supported_languages:
+        if lang == "en":
+          continue
+        examples_by_lang[f"{lang}-en"] = []
+        lines = self._read_tsv(
+            os.path.join(self.translated_data_dir, "translate-test",
+                         f"test-{lang}-en-translated.tsv"))
+        for i, line in enumerate(lines):
+          guid = f"test-{lang}-en-{i}"
+          text_a = self.process_text_fn(line[2])
+          text_b = self.process_text_fn(line[3])
+          label = "contradiction"
+          examples_by_lang[f"{lang}-en"].append(
+              InputExample(
+                  guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples_by_lang
 
   def get_labels(self):
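The new options can be exercised directly on a processor. A minimal sketch, assuming an XTREME-style data layout; the paths below are hypothetical, not from the commit:

from official.nlp.data import classifier_data_lib

# translate-train/en-{lang}-translated.tsv files live under
# translated_data_dir; the original tsv files live under data_dir.
processor = classifier_data_lib.XtremePawsxProcessor(
    translated_data_dir="/data/xtreme/pawsx-translations",  # hypothetical
    only_use_en_dev=False)

# Training examples now cover every supported language via translate-train,
# dev examples come from all dev-{lang}.tsv files, and get_test_examples()
# returns extra "{lang}-en" translate-test groups.
train_examples = processor.get_train_examples("/data/xtreme/pawsx")
examples_by_lang = processor.get_test_examples("/data/xtreme/pawsx")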
official/nlp/data/create_finetuning_data.py

@@ -46,20 +46,19 @@ flags.DEFINE_string(
     "The input data dir. Should contain the .tsv files (or other data files) "
     "for the task.")
 
-flags.DEFINE_enum("classification_task_name", "MNLI",
-                  ["AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI",
-                   "QQP", "RTE", "SST-2", "STS-B", "WNLI", "XNLI",
-                   "XTREME-XNLI", "XTREME-PAWS-X"],
-                  "The name of the task to train BERT classifier. The "
-                  "difference between XTREME-XNLI and XNLI is: 1. the format "
-                  "of input tsv files; 2. the dev set for XTREME is english "
-                  "only and for XNLI is all languages combined. Same for "
-                  "PAWS-X.")
+flags.DEFINE_enum(
+    "classification_task_name", "MNLI", [
+        "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
+        "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X"
+    ], "The name of the task to train BERT classifier. The "
+    "difference between XTREME-XNLI and XNLI is: 1. the format "
+    "of input tsv files; 2. the dev set for XTREME is english "
+    "only and for XNLI is all languages combined. Same for "
+    "PAWS-X.")
 
 # MNLI task-specific flag.
 flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
                   "The type of MNLI dataset.")
 
 # XNLI task-specific flag.
 flags.DEFINE_string(

@@ -73,6 +72,12 @@ flags.DEFINE_string(
     "Language of training data for PAWS-X task. If the value is 'all', the data "
     "of all languages will be used for training.")
 
+# XTREME classification specific flags. Only used in XtremePawsx and XtremeXnli.
+flags.DEFINE_string(
+    "translated_input_data_dir", None,
+    "The translated input data dir. Should contain the .tsv files (or other "
+    "data files) for the task.")
+
 # Retrieval task-specific flags.
 flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
                   "The name of sentence retrieval task for scoring")

@@ -81,11 +86,19 @@ flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
 flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
                   "The name of BERT tagging (token classification) task.")
 
+flags.DEFINE_bool("tagging_only_use_en_train", True,
+                  "Whether only use english training data in tagging.")
+
 # BERT Squad task-specific flags.
 flags.DEFINE_string(
     "squad_data_file", None,
     "The input data file in for generating training data for BERT squad task.")
 
+flags.DEFINE_string(
+    "translated_squad_data_folder", None,
+    "The translated data folder for generating training data for BERT squad "
+    "task.")
+
 flags.DEFINE_integer(
     "doc_stride", 128,
     "When splitting up a long document into chunks, how much stride to "

@@ -105,6 +118,9 @@ flags.DEFINE_bool(
     "If true, then data will be preprocessed in a paragraph, query, class order"
     " instead of the BERT-style class, paragraph, query order.")
 
+# XTREME specific flags.
+flags.DEFINE_bool("only_use_en_dev", True, "Whether only use english dev data.")
+
 # Shared flags across BERT fine-tuning tasks.
 flags.DEFINE_string("vocab_file", None,
                     "The vocabulary file that the BERT model was trained on.")

@@ -148,16 +164,16 @@ flags.DEFINE_enum(
     "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
     "while ALBERT uses SentencePiece tokenizer.")
 
-flags.DEFINE_string("tfds_params", "",
-                    "Comma-separated list of TFDS parameter assigments for "
-                    "generic classfication data import (for more details "
-                    "see the TfdsProcessor class documentation).")
+flags.DEFINE_string(
+    "tfds_params", "", "Comma-separated list of TFDS parameter assigments for "
+    "generic classfication data import (for more details "
+    "see the TfdsProcessor class documentation).")
 
 
 def generate_classifier_dataset():
   """Generates classifier dataset and returns input meta data."""
-  assert (FLAGS.input_data_dir and FLAGS.classification_task_name
-          or FLAGS.tfds_params)
+  assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
+          FLAGS.tfds_params)
 
   if FLAGS.tokenization == "WordPiece":
     tokenizer = tokenization.FullTokenizer(

@@ -171,8 +187,7 @@ def generate_classifier_dataset():
   if FLAGS.tfds_params:
     processor = classifier_data_lib.TfdsProcessor(
-        tfds_params=FLAGS.tfds_params,
-        process_text_fn=processor_text_fn)
+        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
     return classifier_data_lib.generate_tf_record_from_data_file(
         processor,
         None,

@@ -190,29 +205,40 @@ def generate_classifier_dataset():
       "imdb":
           classifier_data_lib.ImdbProcessor,
       "mnli":
           functools.partial(classifier_data_lib.MnliProcessor,
                             mnli_type=FLAGS.mnli_type),
       "mrpc":
           classifier_data_lib.MrpcProcessor,
       "qnli":
           classifier_data_lib.QnliProcessor,
       "qqp":
           classifier_data_lib.QqpProcessor,
       "rte":
           classifier_data_lib.RteProcessor,
       "sst-2":
           classifier_data_lib.SstProcessor,
       "sts-b":
           classifier_data_lib.StsBProcessor,
       "xnli":
           functools.partial(classifier_data_lib.XnliProcessor,
                             language=FLAGS.xnli_language),
       "paws-x":
           functools.partial(classifier_data_lib.PawsxProcessor,
                             language=FLAGS.pawsx_language),
       "wnli":
           classifier_data_lib.WnliProcessor,
       "xtreme-xnli":
-          functools.partial(classifier_data_lib.XtremeXnliProcessor),
+          functools.partial(
+              classifier_data_lib.XtremeXnliProcessor,
+              translated_data_dir=FLAGS.translated_input_data_dir,
+              only_use_en_dev=FLAGS.only_use_en_dev),
       "xtreme-paws-x":
-          functools.partial(classifier_data_lib.XtremePawsxProcessor)
+          functools.partial(
+              classifier_data_lib.XtremePawsxProcessor,
+              translated_data_dir=FLAGS.translated_input_data_dir,
+              only_use_en_dev=FLAGS.only_use_en_dev)
   }
   task_name = FLAGS.classification_task_name.lower()
   if task_name not in processors:

@@ -243,8 +269,7 @@ def generate_regression_dataset():
   if FLAGS.tfds_params:
     processor = classifier_data_lib.TfdsProcessor(
-        tfds_params=FLAGS.tfds_params,
-        process_text_fn=processor_text_fn)
+        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
     return classifier_data_lib.generate_tf_record_from_data_file(
         processor,
         None,

@@ -265,6 +290,7 @@ def generate_squad_dataset():
         input_file_path=FLAGS.squad_data_file,
         vocab_file_path=FLAGS.vocab_file,
         output_path=FLAGS.train_data_output_path,
+        translated_input_folder=FLAGS.translated_squad_data_folder,
         max_seq_length=FLAGS.max_seq_length,
         do_lower_case=FLAGS.do_lower_case,
         max_query_length=FLAGS.max_query_length,

@@ -277,6 +303,7 @@ def generate_squad_dataset():
         input_file_path=FLAGS.squad_data_file,
         sp_model_file=FLAGS.sp_model_file,
         output_path=FLAGS.train_data_output_path,
+        translated_input_folder=FLAGS.translated_squad_data_folder,
         max_seq_length=FLAGS.max_seq_length,
         do_lower_case=FLAGS.do_lower_case,
         max_query_length=FLAGS.max_query_length,

@@ -310,19 +337,23 @@ def generate_retrieval_dataset():
   processor = processors[task_name](process_text_fn=processor_text_fn)
   return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
       processor,
       FLAGS.input_data_dir,
       tokenizer,
       FLAGS.eval_data_output_path,
       FLAGS.test_data_output_path,
       FLAGS.max_seq_length)
 
 
 def generate_tagging_dataset():
   """Generates tagging dataset."""
   processors = {
-      "panx": tagging_data_lib.PanxProcessor,
-      "udpos": tagging_data_lib.UdposProcessor,
+      "panx":
+          functools.partial(
+              tagging_data_lib.PanxProcessor,
+              only_use_en_train=FLAGS.tagging_only_use_en_train,
+              only_use_en_dev=FLAGS.only_use_en_dev),
+      "udpos":
+          functools.partial(
+              tagging_data_lib.UdposProcessor,
+              only_use_en_train=FLAGS.tagging_only_use_en_train,
+              only_use_en_dev=FLAGS.only_use_en_dev),
   }
   task_name = FLAGS.tagging_task_name.lower()
   if task_name not in processors:
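Inside generate_classifier_dataset(), the XTREME entries now bind the new flags through functools.partial, with process_text_fn supplied at call time. A sketch of that flow, using hypothetical stand-ins for the flag values:

import functools

from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib

# Hypothetical stand-ins for --translated_input_data_dir / --only_use_en_dev.
translated_input_data_dir = "/data/xtreme/xnli-translations"
only_use_en_dev = False

processors = {
    "xtreme-xnli":
        functools.partial(
            classifier_data_lib.XtremeXnliProcessor,
            translated_data_dir=translated_input_data_dir,
            only_use_en_dev=only_use_en_dev),
}

# The partial is invoked later, exactly as in the script, which binds the
# remaining process_text_fn argument.
processor = processors["xtreme-xnli"](
    process_text_fn=tokenization.convert_to_unicode)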
official/nlp/data/squad_lib.py

@@ -158,11 +158,20 @@ class FeatureWriter(object):
     self._writer.close()
 
 
-def read_squad_examples(input_file, is_training, version_2_with_negative):
+def read_squad_examples(input_file, is_training, version_2_with_negative,
+                        translated_input_folder=None):
   """Read a SQuAD json file into a list of SquadExample."""
   with tf.io.gfile.GFile(input_file, "r") as reader:
     input_data = json.load(reader)["data"]
 
+  if translated_input_folder is not None:
+    translated_files = tf.io.gfile.glob(
+        os.path.join(translated_input_folder, "*.json"))
+    for file in translated_files:
+      with tf.io.gfile.GFile(file, "r") as reader:
+        input_data.extend(json.load(reader)["data"])
+
   def is_whitespace(c):
     if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
       return True

@@ -930,6 +939,7 @@ def _compute_softmax(scores):
 def generate_tf_record_from_json_file(input_file_path,
                                       vocab_file_path,
                                       output_path,
+                                      translated_input_folder=None,
                                       max_seq_length=384,
                                       do_lower_case=True,
                                       max_query_length=64,

@@ -940,7 +950,8 @@ def generate_tf_record_from_json_file(input_file_path,
   train_examples = read_squad_examples(
       input_file=input_file_path,
       is_training=True,
-      version_2_with_negative=version_2_with_negative)
+      version_2_with_negative=version_2_with_negative,
+      translated_input_folder=translated_input_folder)
   tokenizer = tokenization.FullTokenizer(
       vocab_file=vocab_file_path, do_lower_case=do_lower_case)
   train_writer = FeatureWriter(filename=output_path, is_training=True)
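A minimal sketch of the extended reader (the SentencePiece variant below behaves identically), assuming every *.json under the folder is SQuAD-format; the paths are hypothetical:

from official.nlp.data import squad_lib

examples = squad_lib.read_squad_examples(
    input_file="/data/squad/train-v1.1.json",
    is_training=True,
    version_2_with_negative=False,
    translated_input_folder="/data/xtreme/squad-translate-train")
# The "data" entries of each globbed file are appended to input_data before
# example extraction, so translated articles train alongside the originals.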
official/nlp/data/squad_lib_sp.py

@@ -109,12 +109,22 @@ class InputFeatures(object):
     self.is_impossible = is_impossible
 
 
-def read_squad_examples(input_file, is_training, version_2_with_negative):
+def read_squad_examples(input_file, is_training, version_2_with_negative,
+                        translated_input_folder=None):
   """Read a SQuAD json file into a list of SquadExample."""
   del version_2_with_negative
   with tf.io.gfile.GFile(input_file, "r") as reader:
     input_data = json.load(reader)["data"]
 
+  if translated_input_folder is not None:
+    translated_files = tf.io.gfile.glob(
+        os.path.join(translated_input_folder, "*.json"))
+    for file in translated_files:
+      with tf.io.gfile.GFile(file, "r") as reader:
+        input_data.extend(json.load(reader)["data"])
+
   examples = []
   for entry in input_data:
     for paragraph in entry["paragraphs"]:

@@ -922,6 +932,7 @@ class FeatureWriter(object):
 def generate_tf_record_from_json_file(input_file_path,
                                       sp_model_file,
                                       output_path,
+                                      translated_input_folder=None,
                                       max_seq_length=384,
                                       do_lower_case=True,
                                       max_query_length=64,

@@ -932,7 +943,8 @@ def generate_tf_record_from_json_file(input_file_path,
   train_examples = read_squad_examples(
       input_file=input_file_path,
       is_training=True,
-      version_2_with_negative=version_2_with_negative)
+      version_2_with_negative=version_2_with_negative,
+      translated_input_folder=translated_input_folder)
   tokenizer = tokenization.FullSentencePieceTokenizer(
       sp_model_file=sp_model_file)
   train_writer = FeatureWriter(
official/nlp/data/tagging_data_lib.py

@@ -19,6 +19,7 @@ import os
 from absl import logging
 import tensorflow as tf
 
+from official.nlp.bert import tokenization
 from official.nlp.data import classifier_data_lib
 
 # A negative label id for the padding label, which will not contribute

@@ -89,13 +90,48 @@ class PanxProcessor(classifier_data_lib.DataProcessor):
       "tr", "et", "fi", "hu"
   ]
 
+  def __init__(self,
+               process_text_fn=tokenization.convert_to_unicode,
+               only_use_en_train=True,
+               only_use_en_dev=True):
+    """See base class.
+
+    Arguments:
+      process_text_fn: See base class.
+      only_use_en_train: If True, only use english training data. Otherwise, use
+        training data from all languages.
+      only_use_en_dev: If True, only use english dev data. Otherwise, use dev
+        data from all languages.
+    """
+    super(PanxProcessor, self).__init__(process_text_fn)
+    self.only_use_en_train = only_use_en_train
+    self.only_use_en_dev = only_use_en_dev
+
   def get_train_examples(self, data_dir):
-    return _read_one_file(
+    examples = _read_one_file(
         os.path.join(data_dir, "train-en.tsv"), self.get_labels())
+    if not self.only_use_en_train:
+      for language in self.supported_languages:
+        if language == "en":
+          continue
+        examples.extend(
+            _read_one_file(
+                os.path.join(data_dir, f"train-{language}.tsv"),
+                self.get_labels()))
+    return examples
 
   def get_dev_examples(self, data_dir):
-    return _read_one_file(
+    examples = _read_one_file(
         os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
+    if not self.only_use_en_dev:
+      for language in self.supported_languages:
+        if language == "en":
+          continue
+        examples.extend(
+            _read_one_file(
+                os.path.join(data_dir, f"dev-{language}.tsv"),
+                self.get_labels()))
+    return examples
 
   def get_test_examples(self, data_dir):
     examples_dict = {}

@@ -120,13 +156,49 @@ class UdposProcessor(classifier_data_lib.DataProcessor):
       "ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"
   ]
 
+  def __init__(self,
+               process_text_fn=tokenization.convert_to_unicode,
+               only_use_en_train=True,
+               only_use_en_dev=True):
+    """See base class.
+
+    Arguments:
+      process_text_fn: See base class.
+      only_use_en_train: If True, only use english training data. Otherwise, use
+        training data from all languages.
+      only_use_en_dev: If True, only use english dev data. Otherwise, use dev
+        data from all languages.
+    """
+    super(UdposProcessor, self).__init__(process_text_fn)
+    self.only_use_en_train = only_use_en_train
+    self.only_use_en_dev = only_use_en_dev
+
   def get_train_examples(self, data_dir):
-    return _read_one_file(
-        os.path.join(data_dir, "train-en.tsv"), self.get_labels())
+    if self.only_use_en_train:
+      examples = _read_one_file(
+          os.path.join(data_dir, "train-en.tsv"), self.get_labels())
+    else:
+      examples = []
+      # Uses glob because some languages are missing in train.
+      for filepath in tf.io.gfile.glob(os.path.join(data_dir, "train-*.tsv")):
+        examples.extend(_read_one_file(filepath, self.get_labels()))
+    return examples
 
   def get_dev_examples(self, data_dir):
-    return _read_one_file(
-        os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
+    if self.only_use_en_dev:
+      examples = _read_one_file(
+          os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
+    else:
+      examples = []
+      for filepath in tf.io.gfile.glob(os.path.join(data_dir, "dev-*.tsv")):
+        examples.extend(_read_one_file(filepath, self.get_labels()))
+    return examples
 
   def get_test_examples(self, data_dir):
     examples_dict = {}
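With only_use_en_train=False, UdposProcessor switches to globbing train-*.tsv, which tolerates languages that have no training split. A short sketch with a hypothetical data directory:

from official.nlp.data import tagging_data_lib

processor = tagging_data_lib.UdposProcessor(
    only_use_en_train=False,  # read train-*.tsv instead of train-en.tsv only
    only_use_en_dev=True)
train_examples = processor.get_train_examples("/data/xtreme/udpos")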
official/nlp/tasks/sentence_prediction.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 """Sentence prediction (classification) task."""
-from typing import List, Union
+from typing import List, Union, Optional
 from absl import logging
 import dataclasses

@@ -159,8 +159,7 @@ class SentencePredictionTask(base_task.Task):
     logs = {self.loss: loss}
     if self.metric_type == 'matthews_corrcoef':
       logs.update({
           'sentence_prediction':
               # Ensure one prediction along batch dimension.
-              tf.expand_dims(
-                  tf.math.argmax(outputs, axis=1), axis=1),
+              tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
           'labels':
               labels,

@@ -228,32 +227,34 @@ class SentencePredictionTask(base_task.Task):
         ckpt_dir_or_file)
 
 
-def predict(task: SentencePredictionTask, params: cfg.DataConfig,
-            model: tf.keras.Model) -> List[Union[int, float]]:
+def predict(task: SentencePredictionTask,
+            params: cfg.DataConfig,
+            model: tf.keras.Model,
+            params_aug: Optional[cfg.DataConfig] = None,
+            test_time_aug_wgt: float = 0.3) -> List[Union[int, float]]:
   """Predicts on the input data.
 
   Args:
     task: A `SentencePredictionTask` object.
     params: A `cfg.DataConfig` object.
     model: A keras.Model.
+    params_aug: A `cfg.DataConfig` object for augmented data.
+    test_time_aug_wgt: Test time augmentation weight. The prediction score will
+      use (1. - test_time_aug_wgt) original prediction plus test_time_aug_wgt
+      augmented prediction.
 
   Returns:
     A list of predictions with length of `num_examples`. For regression task,
     each element in the list is the predicted score; for classification task,
     each element is the predicted class id.
   """
-  is_regression = task.task_config.model.num_classes == 1
 
   def predict_step(inputs):
     """Replicated prediction calculation."""
     x, _ = inputs
     example_id = x.pop('example_id')
     outputs = task.inference_step(x, model)
-    if is_regression:
-      return dict(example_id=example_id, predictions=outputs)
-    else:
-      return dict(
-          example_id=example_id,
-          predictions=tf.argmax(outputs, axis=-1))
+    return dict(example_id=example_id, predictions=outputs)
 
   def aggregate_fn(state, outputs):
     """Concatenates model's outputs."""

@@ -272,4 +273,22 @@ def predict(task: SentencePredictionTask, params: cfg.DataConfig,
   # When running on TPU POD, the order of output cannot be maintained,
   # so we need to sort by example_id.
   outputs = sorted(outputs, key=lambda x: x[0])
-  return [x[1] for x in outputs]
+  is_regression = task.task_config.model.num_classes == 1
+  if params_aug is not None:
+    dataset_aug = orbit.utils.make_distributed_dataset(
+        tf.distribute.get_strategy(), task.build_inputs, params_aug)
+    outputs_aug = utils.predict(predict_step, aggregate_fn, dataset_aug)
+    outputs_aug = sorted(outputs_aug, key=lambda x: x[0])
+    if is_regression:
+      return [(1. - test_time_aug_wgt) * x[1] + test_time_aug_wgt * y[1]
+              for x, y in zip(outputs, outputs_aug)]
+    else:
+      return [
+          tf.argmax(
+              (1. - test_time_aug_wgt) * x[1] + test_time_aug_wgt * y[1],
+              axis=-1) for x, y in zip(outputs, outputs_aug)
+      ]
+  if is_regression:
+    return [x[1] for x in outputs]
+  else:
+    return [tf.argmax(x[1], axis=-1) for x in outputs]
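A hedged usage sketch of the new test-time-augmentation path; task, model, params, and an augmented config (called params_translated here, a hypothetical name) are assumed to already exist:

# params_translated would point at e.g. a translate-test eval split.
preds = predict(task, params, model,
                params_aug=params_translated,
                test_time_aug_wgt=0.3)
# With the default weight, scores blend as 0.7 * original + 0.3 * augmented;
# classification applies tf.argmax after blending, regression returns the
# blended score directly.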