# chenpangpang/transformers: Commit 20d6931e

Unverified commit, authored Apr 30, 2021 by Matt, committed via GitHub on Apr 30, 2021. Parent: 8b945ef0

**Update TF text classification example (#11496)**

Big refactor, fixes and multi-GPU/TPU support.
Showing 3 changed files with 201 additions and 187 deletions:

- `examples/tensorflow/text-classification/README.md` (+14, -0)
- `examples/tensorflow/text-classification/run_text_classification.py` (+182, -185)
- `src/transformers/training_args_tf.py` (+5, -2)
## examples/tensorflow/text-classification/README.md

Hunk `@@ -54,6 +54,20 @@` follows the existing note that, after training, the model will be saved to `--output_dir`; once your model is trained, you can get predictions by calling the script without a `--train_file` or `--validation_file`, simply passing it the output_dir containing the trained model and a `--test_file`, and it will write its predictions to a text file for you. The hunk adds the following new sections:
### Multi-GPU and TPU usage

By default, the script uses a `MirroredStrategy` and will use multiple GPUs effectively if they are available. TPUs can also be used by passing the name of the TPU resource with the `--tpu` argument.
### Memory usage and data loading

One thing to note is that all data is loaded into memory in this script. Most text classification datasets are small enough that this is not an issue, but if you have a very large dataset you will need to modify the script to handle data streaming. This is particularly challenging for TPUs, given the stricter requirements and the sheer volume of data required to keep them fed. A full explanation of all the possible pitfalls is a bit beyond this example script and README, but for more information you can see the 'Input Datasets' section of [this document](https://www.tensorflow.org/guide/tpu).
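The streaming approach this paragraph alludes to would typically be built on `tf.data`. A minimal sketch, illustrative only and not part of this commit, where the generator stands in for code that reads and tokenizes records from disk one at a time:

```python
import tensorflow as tf

# Illustrative only: stream examples from a generator instead of loading
# everything into memory at once.
def example_generator():
    for length in (5, 9, 7):
        yield {"input_ids": tf.range(length, dtype=tf.int32)}

streamed = tf.data.Dataset.from_generator(
    example_generator,
    output_signature={"input_ids": tf.TensorSpec(shape=(None,), dtype=tf.int32)},
)
# TPUs need fixed shapes, so pad every batch to a constant length.
streamed = streamed.padded_batch(2, padded_shapes={"input_ids": (16,)})
for batch in streamed:
    print(batch["input_ids"].shape)  # (batch_size, 16)
```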
### Example command

```
python run_text_classification.py \
...
```
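The example command is truncated at this point in the diff view. Purely for orientation, a representative invocation, with all values hypothetical and not taken from the commit, might look like:

```
# Hypothetical example; flag values are illustrative, not from the commit.
python run_text_classification.py \
  --model_name_or_path bert-base-cased \
  --train_file train.csv \
  --validation_file validation.csv \
  --output_dir output/ \
  --tpu my-tpu  # optional: name of a TPU resource
```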
## examples/tensorflow/text-classification/run_text_classification.py
```diff
@@ -18,10 +18,8 @@
 import logging
 import os
 import random
 import sys
 from dataclasses import dataclass, field
-from math import ceil
-from pathlib import Path
 from typing import Optional
```
```diff
@@ -34,7 +32,7 @@ from transformers import (
     HfArgumentParser,
     PretrainedConfig,
     TFAutoModelForSequenceClassification,
-    TrainingArguments,
+    TFTrainingArguments,
     set_seed,
 )
 from transformers.file_utils import CONFIG_NAME, TF2_WEIGHTS_NAME
```
```diff
@@ -48,65 +46,6 @@ logger = logging.getLogger(__name__)

 # region Helper classes
-class DataSequence(tf.keras.utils.Sequence):
-    # We use a Sequence object to load the data. Although it's completely possible to load your data as Numpy/TF arrays
-    # and pass those straight to the Model, this constrains you in a couple of ways. Most notably, it requires all
-    # the data to be padded to the length of the longest input example, and it also requires the whole dataset to be
-    # loaded into memory. If these aren't major problems for you, you can skip the sequence object in your own code!
-    def __init__(self, dataset, non_label_column_names, batch_size, labels, shuffle=True):
-        super().__init__()
-        # Retain all of the columns not present in the original data - these are the ones added by the tokenizer
-        self.data = {
-            key: dataset[key]
-            for key in dataset.features.keys()
-            if key not in non_label_column_names and key != "label"
-        }
-        data_lengths = {len(array) for array in self.data.values()}
-        assert len(data_lengths) == 1, "Dataset arrays differ in length!"
-        self.data_length = data_lengths.pop()
-        self.num_batches = ceil(self.data_length / batch_size)
-        if labels:
-            self.labels = np.array(dataset["label"])
-            assert len(self.labels) == self.data_length, "Labels not the same length as input arrays!"
-        else:
-            self.labels = None
-        self.batch_size = batch_size
-        self.shuffle = shuffle
-        if self.shuffle:
-            # Shuffle the data order
-            self.permutation = np.random.permutation(self.data_length)
-        else:
-            self.permutation = None
-
-    def on_epoch_end(self):
-        # If we're shuffling, reshuffle the data order after each epoch
-        if self.shuffle:
-            self.permutation = np.random.permutation(self.data_length)
-
-    def __getitem__(self, item):
-        # Note that this yields a batch, not a single sample
-        batch_start = item * self.batch_size
-        batch_end = (item + 1) * self.batch_size
-        if self.shuffle:
-            data_indices = self.permutation[batch_start:batch_end]
-        else:
-            data_indices = np.arange(batch_start, batch_end)
-        # We want to pad the data as little as possible, so we only pad each batch
-        # to the maximum length within that batch. We do that by stacking the variable-
-        # length inputs into a ragged tensor and then densifying it.
-        batch_input = {
-            key: tf.ragged.constant([data[i] for i in data_indices]).to_tensor()
-            for key, data in self.data.items()
-        }
-        if self.labels is None:
-            return batch_input
-        else:
-            batch_labels = self.labels[data_indices]
-            return batch_input, batch_labels
-
-    def __len__(self):
-        return self.num_batches
-
-
 class SavePretrainedCallback(tf.keras.callbacks.Callback):
     # Hugging Face models have a save_pretrained() method that saves both the weights and the necessary
     # metadata to allow them to be loaded as a pretrained model in future. This is a simple Keras callback
```
```diff
@@ -119,8 +58,50 @@ class SavePretrainedCallback(tf.keras.callbacks.Callback):
         self.model.save_pretrained(self.output_dir)


+def convert_dataset_for_tensorflow(
+    dataset, non_label_column_names, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=True
+):
+    """Converts a Hugging Face dataset to a Tensorflow Dataset. The dataset_mode controls whether we pad all batches
+    to the maximum sequence length, or whether we only pad to the maximum length within that batch. The former
+    is most useful when training on TPU, as a new graph compilation is required for each sequence length.
+    """
+
+    def densify_ragged_batch(features, label=None):
+        features = {
+            feature: ragged_tensor.to_tensor(shape=batch_shape[feature])
+            for feature, ragged_tensor in features.items()
+        }
+        if label is None:
+            return features
+        else:
+            return features, label
+
+    feature_keys = list(set(dataset.features.keys()) - set(non_label_column_names + ["label"]))
+    if dataset_mode == "variable_batch":
+        batch_shape = {key: None for key in feature_keys}
+        data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
+    elif dataset_mode == "constant_batch":
+        data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
+        batch_shape = {
+            key: tf.concat(([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0)
+            for key, ragged_tensor in data.items()
+        }
+    else:
+        raise ValueError("Unknown dataset mode!")
+
+    if "label" in dataset.features:
+        labels = tf.convert_to_tensor(np.array(dataset["label"]))
+        tf_dataset = tf.data.Dataset.from_tensor_slices((data, labels))
+    else:
+        tf_dataset = tf.data.Dataset.from_tensor_slices(data)
+    if shuffle:
+        tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
+    tf_dataset = tf_dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder).map(densify_ragged_batch)
+    return tf_dataset
+
+
 # endregion

 # region Command-line arguments
 @dataclass
 class DataTrainingArguments:
```
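The densify step above is the heart of the `"variable_batch"` vs `"constant_batch"` distinction. A toy sketch, not from the commit, showing both behaviors of `RaggedTensor.to_tensor()`:

```python
import tensorflow as tf

# Toy sketch (not from the commit) of the per-batch padding trick used above.
ragged = tf.ragged.constant([[1, 2], [3, 4, 5], [6]])

# "variable_batch": densify with no target shape, so each batch is padded
# only to the longest sequence it actually contains.
print(ragged.to_tensor())               # shape (3, 3), zero-padded

# "constant_batch": densify to a fixed shape so every batch has identical
# dimensions, which is what TPU graph compilation wants.
print(ragged.to_tensor(shape=(3, 8)))   # shape (3, 8)
```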
```diff
@@ -155,6 +136,7 @@ class DataTrainingArguments:
         metadata={
             "help": "Whether to pad all samples to `max_seq_length`. "
             "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+            "Data will always be padded when using TPUs."
         },
     )
     max_train_samples: Optional[int] = field(
```
```diff
@@ -164,17 +146,17 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_eval_samples: Optional[int] = field(
+    max_val_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
             "value if set."
         },
     )
-    max_predict_samples: Optional[int] = field(
+    max_test_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of predict examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
             "value if set."
         },
     )
```
```diff
@@ -223,6 +205,7 @@ class ModelArguments:
             "with private models)."
         },
     )
+    tpu: Optional[str] = field(default=None, metadata={"help": "Name of the TPU resource to use, if available"})

 # endregion
```
```diff
@@ -234,7 +217,7 @@ def main():
     # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
```
```diff
@@ -322,12 +305,7 @@ def main():
         is_regression = None
     # endregion

-    # region Load pretrained model and tokenizer
-    # Set seed before initializing model
-    set_seed(training_args.seed)
-    #
-    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
-    # download model & vocab.
+    # region Load model config and tokenizer
     if checkpoint is not None:
         config_path = training_args.output_dir
     elif model_args.config_name:
```
```diff
@@ -355,34 +333,6 @@ def main():
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
     )
-
-    if checkpoint is None:
-        model_path = model_args.model_name_or_path
-    else:
-        model_path = checkpoint
-    model = TFAutoModelForSequenceClassification.from_pretrained(
-        model_path,
-        config=config,
-        cache_dir=model_args.cache_dir,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
-    # endregion
-
-    # region Optimizer, loss and compilation
-    optimizer = tf.keras.optimizers.Adam(
-        learning_rate=training_args.learning_rate,
-        beta_1=training_args.adam_beta1,
-        beta_2=training_args.adam_beta2,
-        epsilon=training_args.adam_epsilon,
-        clipnorm=training_args.max_grad_norm,
-    )
-    if is_regression:
-        loss = tf.keras.losses.MeanSquaredError()
-        metrics = []
-    else:
-        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
-        metrics = ["accuracy"]
-    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
     # endregion

     # region Dataset preprocessing
```
```diff
@@ -399,13 +349,6 @@ def main():
     else:
         sentence1_key, sentence2_key = non_label_column_names[0], None

-    # Padding strategy
-    if data_args.pad_to_max_length:
-        padding = "max_length"
-    else:
-        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
-        padding = False
-
     if data_args.max_seq_length > tokenizer.model_max_length:
         logger.warning(
             f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
```
```diff
@@ -415,8 +358,8 @@ def main():
     # Ensure that our labels match the model's, if it has some pre-specified
     if "train" in datasets:
-        if not is_regression and model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
-            label_name_to_id = model.config.label2id
+        if not is_regression and config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
+            label_name_to_id = config.label2id
             if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
                 label_to_id = label_name_to_id  # Use the model's labels
             else:
```
```diff
@@ -431,15 +374,15 @@ def main():
         else:
             label_to_id = None
         # Now we've established our label2id, let's overwrite the model config with it.
-        model.config.label2id = label_to_id
-        if model.config.label2id is not None:
-            model.config.id2label = {id: label for label, id in label_to_id.items()}
+        config.label2id = label_to_id
+        if config.label2id is not None:
+            config.id2label = {id: label for label, id in label_to_id.items()}
         else:
-            model.config.id2label = None
+            config.id2label = None
     else:
-        label_to_id = model.config.label2id  # Just load the data from the model
+        label_to_id = config.label2id  # Just load the data from the model

-    if "validation" in datasets and model.config.label2id is not None:
+    if "validation" in datasets and config.label2id is not None:
         validation_label_list = datasets["validation"].unique("label")
         for val_label in validation_label_list:
             assert val_label in label_to_id, f"Label {val_label} is in the validation set but not the training set!"
```
```diff
@@ -449,87 +392,141 @@ def main():
         args = (
             (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
         )
-        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
+        result = tokenizer(*args, max_length=max_seq_length, truncation=True)

         # Map labels to IDs
-        if model.config.label2id is not None and "label" in examples:
-            result["label"] = [(model.config.label2id[l] if l != -1 else -1) for l in examples["label"]]
+        if config.label2id is not None and "label" in examples:
+            result["label"] = [(config.label2id[l] if l != -1 else -1) for l in examples["label"]]
         return result

     datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
-
-    if "train" in datasets:
-        train_dataset = datasets["train"]
-        if data_args.max_train_samples is not None:
-            train_dataset = train_dataset.select(range(data_args.max_train_samples))
-        # Log a few random samples from the training set so we can see that it's working as expected:
-        for index in random.sample(range(len(train_dataset)), 3):
-            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
-
-    if "validation" in datasets:
-        eval_dataset = datasets["validation"]
-        if data_args.max_eval_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
-
-    if "test" in datasets:
-        predict_dataset = datasets["test"]
-        if data_args.max_predict_samples is not None:
-            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
-
     # endregion

-    # region Training
-    if "train" in datasets:
-        training_dataset = DataSequence(
-            train_dataset, non_label_column_names, batch_size=training_args.per_device_train_batch_size, labels=True
-        )
-        if "validation" in datasets:
-            eval_dataset = DataSequence(
-                eval_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=True
-            )
+    with training_args.strategy.scope():
+        # region Load pretrained model
+        # Set seed before initializing model
+        set_seed(training_args.seed)
+        #
+        # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
+        # download model & vocab.
+        if checkpoint is None:
+            model_path = model_args.model_name_or_path
         else:
-            eval_dataset = None
-
-        callbacks = [SavePretrainedCallback(output_dir=training_args.output_dir)]
-        model.fit(
-            training_dataset,
-            validation_data=eval_dataset,
-            epochs=int(training_args.num_train_epochs),
-            callbacks=callbacks,
+            model_path = checkpoint
+        model = TFAutoModelForSequenceClassification.from_pretrained(
+            model_path,
+            config=config,
+            cache_dir=model_args.cache_dir,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
         )
-    elif "validation" in datasets:
-        # If there's a validation dataset but no training set, just evaluate the metrics
-        eval_dataset = DataSequence(
-            eval_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=True
+        # endregion
+
+        # region Optimizer, loss and compilation
+        optimizer = tf.keras.optimizers.Adam(
+            learning_rate=training_args.learning_rate,
+            beta_1=training_args.adam_beta1,
+            beta_2=training_args.adam_beta2,
+            epsilon=training_args.adam_epsilon,
+            clipnorm=training_args.max_grad_norm,
         )
-        logger.info("Computing metrics on validation data...")
         if is_regression:
-            loss = model.evaluate(eval_dataset)
-            logger.info(f"Loss: {loss:.5f}")
+            loss_fn = tf.keras.losses.MeanSquaredError()
+            metrics = []
         else:
-            loss, accuracy = model.evaluate(eval_dataset)
-            logger.info(f"Loss: {loss:.5f}, Accuracy: {accuracy * 100:.4f}%")
-    # endregion
-
-    # region Prediction
-    if "test" in datasets:
-        logger.info("Doing predictions on Predict dataset...")
-        predict_dataset = DataSequence(
-            predict_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
-        )
-        predictions = model.predict(predict_dataset)["logits"]
-        predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
-        output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
-        with open(output_predict_file, "w") as writer:
-            writer.write("index\tprediction\n")
-            for index, item in enumerate(predictions):
-                if is_regression:
-                    writer.write(f"{index}\t{item:3.3f}\n")
-                else:
-                    item = model.config.id2label[item]
-                    writer.write(f"{index}\t{item}\n")
-        logger.info(f"Wrote predictions to {output_predict_file}!")
-    # endregion
+            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+            metrics = ["accuracy"]
+        model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)
+        # endregion
+
+        # region Convert data to TF format
+        # Convert data to a tf.keras.utils.Sequence object for training if we're not using a TPU
+        # For TPU, convert to a tf.data.Dataset
+        tf_data = dict()
+        max_samples = {
+            "train": data_args.max_train_samples,
+            "validation": data_args.max_val_samples,
+            "test": data_args.max_test_samples,
+        }
+        for key in ("train", "validation", "test"):
+            if key not in datasets:
+                tf_data[key] = None
+                continue
+            if key in ("train", "validation"):
+                assert "label" in datasets[key].features, f"Missing labels from {key} data!"
+            if key == "train":
+                shuffle = True
+                batch_size = training_args.per_device_train_batch_size
+                drop_remainder = True  # Saves us worrying about scaling gradients for the last batch
+            else:
+                shuffle = False
+                batch_size = training_args.per_device_eval_batch_size
+                drop_remainder = False
+            samples_limit = max_samples[key]
+            dataset = datasets[key]
+            if samples_limit is not None:
+                dataset = dataset.select(range(samples_limit))
+            if isinstance(training_args.strategy, tf.distribute.TPUStrategy) or data_args.pad_to_max_length:
+                logger.info("Padding all batches to max length because argument was set or we're on TPU.")
+                dataset_mode = "constant_batch"
+            else:
+                dataset_mode = "variable_batch"
+            data = convert_dataset_for_tensorflow(
+                dataset,
+                non_label_column_names,
+                batch_size=batch_size,
+                dataset_mode=dataset_mode,
+                drop_remainder=drop_remainder,
+                shuffle=shuffle,
+            )
+            tf_data[key] = data
+        # endregion
+
+        # region Training and validation
+        if tf_data["train"] is not None:
+            callbacks = [SavePretrainedCallback(output_dir=training_args.output_dir)]
+            model.fit(
+                tf_data["train"],
+                validation_data=tf_data["validation"],
+                epochs=int(training_args.num_train_epochs),
+                callbacks=callbacks,
+            )
+        elif tf_data["validation"] is not None:
+            # If there's a validation dataset but no training set, just evaluate the metrics
+            logger.info("Computing metrics on validation data...")
+            if is_regression:
+                loss = model.evaluate(tf_data["validation"])
+                logger.info(f"Loss: {loss:.5f}")
+            else:
+                loss, accuracy = model.evaluate(tf_data["validation"])
+                logger.info(f"Loss: {loss:.5f}, Accuracy: {accuracy * 100:.4f}%")
+        # endregion
+
+        # region Prediction
+        if tf_data["test"] is not None:
+            logger.info("Doing predictions on test dataset...")
+            predictions = model.predict(tf_data["test"])["logits"]
+            predicted_class = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
+            output_test_file = os.path.join(training_args.output_dir, "test_results.txt")
+            with open(output_test_file, "w") as writer:
+                writer.write("index\tprediction\n")
+                for index, item in enumerate(predicted_class):
+                    if is_regression:
+                        writer.write(f"{index}\t{item:3.3f}\n")
+                    else:
+                        item = config.id2label[item]
+                        writer.write(f"{index}\t{item}\n")
+            logger.info(f"Wrote predictions to {output_test_file}!")
+        # endregion
+
+    # region Prediction losses
+    # This section is outside the scope() because it's very quick to compute, but behaves badly inside it
+    if "label" in datasets["test"].features:
+        print("Computing prediction loss on test labels...")
+        labels = datasets["test"]["label"]
+        loss = float(loss_fn(labels, predictions).numpy())
+        print(f"Test loss: {loss:.4f}")
+    # endregion
```
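The biggest structural change in the hunk above is that model creation and `compile()` now happen inside `training_args.strategy.scope()`, so the same code path works under `MirroredStrategy` and `TPUStrategy`. A minimal standalone sketch of that Keras pattern, illustrative and not taken from the commit:

```python
import tensorflow as tf

# Minimal sketch of the pattern the refactor adopts: variables must be created
# inside the distribution strategy's scope so they are replicated across devices.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
# fit()/evaluate()/predict() can then be called as usual; the strategy splits
# each batch across the available replicas.
```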
## src/transformers/training_args_tf.py
```diff
@@ -212,7 +212,10 @@ class TFTrainingArguments(TrainingArguments):
             else:
                 tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
         except ValueError:
-            tpu = None
+            if self.tpu_name:
+                raise RuntimeError(f"Couldn't connect to TPU {self.tpu_name}!")
+            else:
+                tpu = None

         if tpu:
             # Set to bfloat16 in case of TPU
```
```diff
@@ -233,7 +236,7 @@ class TFTrainingArguments(TrainingArguments):
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
             strategy = tf.distribute.MirroredStrategy()
         else:
-            raise ValueError("Cannot find the proper strategy please check your environment properties.")
+            raise ValueError("Cannot find the proper strategy, please check your environment properties.")

         return strategy
```
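For reference, connecting to a named TPU, which is what `TFTrainingArguments` does with `tpu_name` above, generally follows this sequence. A sketch under the assumption that a TPU resource called `my-tpu` is reachable; not part of the diff:

```python
import tensorflow as tf

# Sketch (not from this diff) of the usual TPU connection sequence built on
# tf.distribute.cluster_resolver.TPUClusterResolver. "my-tpu" is hypothetical.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
print("Replicas:", strategy.num_replicas_in_sync)
```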