chenpangpang / transformers · Commits

Commit 62d84760 (unverified) · authored Mar 08, 2022 by Joao Gante, committed by GitHub on Mar 08, 2022

Update TF multiple choice example (#15868)
parent ab2f8d12

Showing 1 changed file with 89 additions and 50 deletions.

examples/tensorflow/multiple-choice/run_swag.py (+89 −50)
@@ -24,10 +24,9 @@ import sys
 from dataclasses import dataclass, field
 from itertools import chain
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
 
 import datasets
-import numpy as np
 import tensorflow as tf
 from datasets import load_dataset
@@ -37,12 +36,15 @@ from transformers import (
     TF2_WEIGHTS_NAME,
     AutoConfig,
     AutoTokenizer,
+    DefaultDataCollator,
     HfArgumentParser,
     TFAutoModelForMultipleChoice,
     TFTrainingArguments,
     create_optimizer,
     set_seed,
 )
+from transformers.file_utils import PaddingStrategy
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from transformers.utils import check_min_version
@@ -65,51 +67,61 @@ class SavePretrainedCallback(tf.keras.callbacks.Callback):
         self.model.save_pretrained(self.output_dir)
 
 
-def convert_dataset_for_tensorflow(
-    dataset, non_label_column_names, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=True
-):
-    """Converts a Hugging Face dataset to a Tensorflow Dataset. The dataset_mode controls whether we pad all batches
-    to the maximum sequence length, or whether we only pad to the maximum length within that batch. The former
-    is most useful when training on TPU, as a new graph compilation is required for each sequence length.
-    """
-
-    def densify_ragged_batch(features, label=None):
-        features = {
-            feature: ragged_tensor.to_tensor(shape=batch_shape[feature]) for feature, ragged_tensor in features.items()
-        }
-        if label is None:
-            return features
-        else:
-            return features, label
-
-    feature_keys = list(set(dataset.features.keys()) - set(non_label_column_names + ["label"]))
-    if dataset_mode == "variable_batch":
-        batch_shape = {key: None for key in feature_keys}
-        data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
-    elif dataset_mode == "constant_batch":
-        data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
-        batch_shape = {
-            key: tf.concat(([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0)
-            for key, ragged_tensor in data.items()
-        }
-    else:
-        raise ValueError("Unknown dataset mode!")
-
-    if "label" in dataset.features:
-        labels = tf.convert_to_tensor(np.array(dataset["label"]))
-        tf_dataset = tf.data.Dataset.from_tensor_slices((data, labels))
-    else:
-        tf_dataset = tf.data.Dataset.from_tensor_slices(data)
-    if shuffle:
-        tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
-    options = tf.data.Options()
-    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
-    tf_dataset = (
-        tf_dataset.with_options(options)
-        .batch(batch_size=batch_size, drop_remainder=drop_remainder)
-        .map(densify_ragged_batch)
-    )
-    return tf_dataset
+@dataclass
+class DataCollatorForMultipleChoice:
+    """
+    Data collator that will dynamically pad the inputs for multiple choice received.
+
+    Args:
+        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
+            The tokenizer used for encoding the data.
+        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+            among:
+
+            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+              sequence if provided).
+            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+              maximum acceptable input length for the model if that argument is not provided.
+            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
+              different lengths).
+        max_length (:obj:`int`, `optional`):
+            Maximum length of the returned list and optionally padding length (see above).
+        pad_to_multiple_of (:obj:`int`, `optional`):
+            If set will pad the sequence to a multiple of the provided value.
+
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+            7.5 (Volta).
+    """
+
+    tokenizer: PreTrainedTokenizerBase
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+
+    def __call__(self, features):
+        label_name = "label" if "label" in features[0].keys() else "labels"
+        labels = [feature.pop(label_name) for feature in features]
+        batch_size = len(features)
+        num_choices = len(features[0]["input_ids"])
+        flattened_features = [
+            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
+        ]
+        flattened_features = list(chain(*flattened_features))
+
+        batch = self.tokenizer.pad(
+            flattened_features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors="tf",
+        )
+
+        # Un-flatten
+        batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
+        # Add back labels
+        batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
+        return batch
 
 
 # endregion
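The new collator pads the flattened (example × choice) features with the tokenizer and then restores the (batch_size, num_choices, seq_len) shape the multiple-choice head expects. Below is a minimal sketch of the collator in isolation, assuming transformers and tensorflow are installed and DataCollatorForMultipleChoice is defined as in the diff above; the bert-base-cased checkpoint and the toy sentences are illustrative assumptions, not part of the commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # assumption: any fast tokenizer works
collator = DataCollatorForMultipleChoice(tokenizer)

def encode(choices, label):
    # Roughly the per-example shape this script's preprocess_function produces:
    # one list of token ids per choice, plus an integer label.
    enc = tokenizer(choices)
    return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"], "label": label}

features = [
    encode(["a short ending", "a much much longer ending", "mid length", "tiny"], label=1),
    encode(["alpha", "beta beta", "gamma gamma gamma", "delta"], label=3),
]

batch = collator(features)
print(batch["input_ids"].shape)  # (2, 4, longest_sequence_in_this_batch)
print(batch["labels"])           # tf.Tensor([1 3], shape=(2,), dtype=int64)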
@@ -382,6 +394,12 @@ def main():
         num_proc=data_args.preprocessing_num_workers,
         load_from_cache_file=not data_args.overwrite_cache,
     )
+
+    if data_args.pad_to_max_length:
+        data_collator = DefaultDataCollator(return_tensors="tf")
+    else:
+        # custom class defined above, as HF has no data collator for multiple choice
+        data_collator = DataCollatorForMultipleChoice(tokenizer)
     # endregion
 
     with training_args.strategy.scope():
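The collator choice is mostly a performance trade-off: with --pad_to_max_length, preprocessing already padded every choice to max_seq_length, so the stock DefaultDataCollator only stacks fixed-shape tensors (the TPU-friendly option, for the same reason the removed convert_dataset_for_tensorflow docstring gave: each new input shape triggers a graph recompilation); without it, the new collator pads each batch only to its own longest sequence. An illustrative sketch of the two options, reusing the tokenizer assumption from the previous sketch:

from transformers import DefaultDataCollator

# Fixed shapes: inputs were padded to max_seq_length during preprocessing,
# so collation is just list-to-tensor conversion. Friendlier to TPU/XLA.
static_collator = DefaultDataCollator(return_tensors="tf")

# Dynamic shapes: each batch is padded to its own longest sequence, which
# wastes less compute on GPU at the cost of variable batch shapes.
dynamic_collator = DataCollatorForMultipleChoice(tokenizer)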
@@ -417,12 +435,26 @@ def main():
         # region Training
         if training_args.do_train:
-            tf_train_dataset = convert_dataset_for_tensorflow(
-                train_dataset, non_label_column_names=non_label_columns, batch_size=total_train_batch_size
+            dataset_exclude_cols = set(non_label_columns + ["label"])
+            tf_train_dataset = train_dataset.to_tf_dataset(
+                columns=[col for col in train_dataset.column_names if col not in dataset_exclude_cols],
+                shuffle=True,
+                batch_size=total_train_batch_size,
+                collate_fn=data_collator,
+                drop_remainder=True,
+                # `label_cols` is needed for user-defined losses, such as in this example
+                label_cols="label" if "label" in train_dataset.column_names else None,
             )
 
         if training_args.do_eval:
-            validation_data = convert_dataset_for_tensorflow(
-                eval_dataset, non_label_column_names=non_label_columns, batch_size=total_eval_batch_size
+            validation_data = eval_dataset.to_tf_dataset(
+                columns=[col for col in eval_dataset.column_names if col not in dataset_exclude_cols],
+                shuffle=False,
+                batch_size=total_eval_batch_size,
+                collate_fn=data_collator,
+                drop_remainder=True,
+                # `label_cols` is needed for user-defined losses, such as in this example
+                label_cols="label" if "label" in eval_dataset.column_names else None,
             )
         else:
             validation_data = None
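Dataset.to_tf_dataset replaces the hand-rolled convert_dataset_for_tensorflow: it streams the Arrow-backed dataset into a tf.data.Dataset, batching through the supplied collate_fn and splitting out label_cols into the (features, labels) pair Keras expects. A sketch of the same call pattern on a toy, already-tokenized dataset, assuming a datasets version that provides to_tf_dataset (as this commit requires) and the tokenizer from the sketches above; the toy data and printed shapes are illustrative assumptions:

from datasets import Dataset

# Toy stand-in for the preprocessed SWAG split: 8 examples, 4 choices each,
# already tokenized to integer ids of varying length.
toy = Dataset.from_dict(
    {
        "input_ids": [[[101, 7], [101, 7, 8], [101], [101, 9]]] * 8,
        "attention_mask": [[[1, 1], [1, 1, 1], [1], [1, 1]]] * 8,
        "label": [0, 1, 2, 3] * 2,
    }
)

tf_toy = toy.to_tf_dataset(
    columns=["input_ids", "attention_mask"],
    shuffle=True,
    batch_size=4,
    collate_fn=DataCollatorForMultipleChoice(tokenizer),
    drop_remainder=True,
    label_cols="label",  # mirrors the script; split out as the Keras target
)

for inputs, labels in tf_toy.take(1):
    print(inputs["input_ids"].shape)  # (4, 4, 3): batch, choices, padded length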
@@ -436,9 +468,16 @@ def main():
         # region Evaluation
         if training_args.do_eval and not training_args.do_train:
+            dataset_exclude_cols = set(non_label_columns + ["label"])
             # Do a standalone evaluation pass
-            tf_eval_dataset = convert_dataset_for_tensorflow(
-                eval_dataset, non_label_column_names=non_label_columns, batch_size=total_eval_batch_size
+            tf_eval_dataset = eval_dataset.to_tf_dataset(
+                columns=[col for col in eval_dataset.column_names if col not in dataset_exclude_cols],
+                shuffle=False,
+                batch_size=total_eval_batch_size,
+                collate_fn=data_collator,
+                drop_remainder=True,
+                # `label_cols` is needed for user-defined losses, such as in this example
+                label_cols="label" if "label" in eval_dataset.column_names else None,
             )
             model.evaluate(tf_eval_dataset)
         # endregion