Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
04c69db3
Unverified
Commit
04c69db3
authored
Oct 30, 2019
by
Thomas Wolf
Committed by
GitHub
Oct 30, 2019
Browse files
Merge pull request #1628 from huggingface/tfglue
run_tf_glue works with all tasks
parents
5c6a19a9
beaf66b1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
8 deletions
+43
-8
examples/run_tf_glue.py
examples/run_tf_glue.py
+35
-8
transformers/data/processors/glue.py
transformers/data/processors/glue.py
+1
-0
transformers/data/processors/utils.py
transformers/data/processors/utils.py
+7
-0
No files found.
examples/run_tf_glue.py
View file @
04c69db3
import
os
import
os
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_datasets
import
tensorflow_datasets
from
transformers
import
BertTokenizer
,
TFBertForSequenceClassification
,
glue_convert_examples_to_features
,
BertForSequenceClassification
from
transformers
import
BertTokenizer
,
TFBertForSequenceClassification
,
BertConfig
,
glue_convert_examples_to_features
,
BertForSequenceClassification
,
glue_processors
# script parameters
# script parameters
BATCH_SIZE
=
32
BATCH_SIZE
=
32
EVAL_BATCH_SIZE
=
BATCH_SIZE
*
2
EVAL_BATCH_SIZE
=
BATCH_SIZE
*
2
USE_XLA
=
False
USE_XLA
=
False
USE_AMP
=
False
USE_AMP
=
False
EPOCHS
=
3
TASK
=
"mrpc"
if
TASK
==
"sst-2"
:
TFDS_TASK
=
"sst2"
elif
TASK
==
"sts-b"
:
TFDS_TASK
=
"stsb"
else
:
TFDS_TASK
=
TASK
num_labels
=
len
(
glue_processors
[
TASK
]().
get_labels
())
print
(
num_labels
)
tf
.
config
.
optimizer
.
set_jit
(
USE_XLA
)
tf
.
config
.
optimizer
.
set_jit
(
USE_XLA
)
tf
.
config
.
optimizer
.
set_experimental_options
({
"auto_mixed_precision"
:
USE_AMP
})
tf
.
config
.
optimizer
.
set_experimental_options
({
"auto_mixed_precision"
:
USE_AMP
})
# Load tokenizer and model from pretrained model/vocabulary
# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
config
=
BertConfig
.
from_pretrained
(
"bert-base-cased"
,
num_labels
=
num_labels
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
'bert-base-cased'
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
'bert-base-cased'
)
model
=
TFBertForSequenceClassification
.
from_pretrained
(
'bert-base-cased'
)
model
=
TFBertForSequenceClassification
.
from_pretrained
(
'bert-base-cased'
,
config
=
config
)
# Load dataset via TensorFlow Datasets
# Load dataset via TensorFlow Datasets
data
,
info
=
tensorflow_datasets
.
load
(
'glue/
mrpc
'
,
with_info
=
True
)
data
,
info
=
tensorflow_datasets
.
load
(
f
'glue/
{
TFDS_TASK
}
'
,
with_info
=
True
)
train_examples
=
info
.
splits
[
'train'
].
num_examples
train_examples
=
info
.
splits
[
'train'
].
num_examples
# MNLI expects either validation_matched or validation_mismatched
valid_examples
=
info
.
splits
[
'validation'
].
num_examples
valid_examples
=
info
.
splits
[
'validation'
].
num_examples
# Prepare dataset for GLUE as a tf.data.Dataset instance
# Prepare dataset for GLUE as a tf.data.Dataset instance
train_dataset
=
glue_convert_examples_to_features
(
data
[
'train'
],
tokenizer
,
128
,
'mrpc'
)
train_dataset
=
glue_convert_examples_to_features
(
data
[
'train'
],
tokenizer
,
128
,
TASK
)
valid_dataset
=
glue_convert_examples_to_features
(
data
[
'validation'
],
tokenizer
,
128
,
'mrpc'
)
# MNLI expects either validation_matched or validation_mismatched
valid_dataset
=
glue_convert_examples_to_features
(
data
[
'validation'
],
tokenizer
,
128
,
TASK
)
train_dataset
=
train_dataset
.
shuffle
(
128
).
batch
(
BATCH_SIZE
).
repeat
(
-
1
)
train_dataset
=
train_dataset
.
shuffle
(
128
).
batch
(
BATCH_SIZE
).
repeat
(
-
1
)
valid_dataset
=
valid_dataset
.
batch
(
EVAL_BATCH_SIZE
)
valid_dataset
=
valid_dataset
.
batch
(
EVAL_BATCH_SIZE
)
...
@@ -32,7 +50,13 @@ opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
...
@@ -32,7 +50,13 @@ opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
if
USE_AMP
:
if
USE_AMP
:
# loss scaling is currently required when using mixed precision
# loss scaling is currently required when using mixed precision
opt
=
tf
.
keras
.
mixed_precision
.
experimental
.
LossScaleOptimizer
(
opt
,
'dynamic'
)
opt
=
tf
.
keras
.
mixed_precision
.
experimental
.
LossScaleOptimizer
(
opt
,
'dynamic'
)
loss
=
tf
.
keras
.
losses
.
SparseCategoricalCrossentropy
(
from_logits
=
True
)
if
num_labels
==
1
:
loss
=
tf
.
keras
.
losses
.
MeanSquaredError
()
else
:
loss
=
tf
.
keras
.
losses
.
SparseCategoricalCrossentropy
(
from_logits
=
True
)
metric
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
'accuracy'
)
metric
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
'accuracy'
)
model
.
compile
(
optimizer
=
opt
,
loss
=
loss
,
metrics
=
[
metric
])
model
.
compile
(
optimizer
=
opt
,
loss
=
loss
,
metrics
=
[
metric
])
...
@@ -40,7 +64,7 @@ model.compile(optimizer=opt, loss=loss, metrics=[metric])
...
@@ -40,7 +64,7 @@ model.compile(optimizer=opt, loss=loss, metrics=[metric])
train_steps
=
train_examples
//
BATCH_SIZE
train_steps
=
train_examples
//
BATCH_SIZE
valid_steps
=
valid_examples
//
EVAL_BATCH_SIZE
valid_steps
=
valid_examples
//
EVAL_BATCH_SIZE
history
=
model
.
fit
(
train_dataset
,
epochs
=
2
,
steps_per_epoch
=
train_steps
,
history
=
model
.
fit
(
train_dataset
,
epochs
=
EPOCHS
,
steps_per_epoch
=
train_steps
,
validation_data
=
valid_dataset
,
validation_steps
=
valid_steps
)
validation_data
=
valid_dataset
,
validation_steps
=
valid_steps
)
# Save TF2 model
# Save TF2 model
...
@@ -57,6 +81,9 @@ sentence_2 = 'His findings were not compatible with this research.'
...
@@ -57,6 +81,9 @@ sentence_2 = 'His findings were not compatible with this research.'
inputs_1
=
tokenizer
.
encode_plus
(
sentence_0
,
sentence_1
,
add_special_tokens
=
True
,
return_tensors
=
'pt'
)
inputs_1
=
tokenizer
.
encode_plus
(
sentence_0
,
sentence_1
,
add_special_tokens
=
True
,
return_tensors
=
'pt'
)
inputs_2
=
tokenizer
.
encode_plus
(
sentence_0
,
sentence_2
,
add_special_tokens
=
True
,
return_tensors
=
'pt'
)
inputs_2
=
tokenizer
.
encode_plus
(
sentence_0
,
sentence_2
,
add_special_tokens
=
True
,
return_tensors
=
'pt'
)
del
inputs_1
[
"special_tokens_mask"
]
del
inputs_2
[
"special_tokens_mask"
]
pred_1
=
pytorch_model
(
**
inputs_1
)[
0
].
argmax
().
item
()
pred_1
=
pytorch_model
(
**
inputs_1
)[
0
].
argmax
().
item
()
pred_2
=
pytorch_model
(
**
inputs_2
)[
0
].
argmax
().
item
()
pred_2
=
pytorch_model
(
**
inputs_2
)[
0
].
argmax
().
item
()
print
(
'sentence_1 is'
,
'a paraphrase'
if
pred_1
else
'not a paraphrase'
,
'of sentence_0'
)
print
(
'sentence_1 is'
,
'a paraphrase'
if
pred_1
else
'not a paraphrase'
,
'of sentence_0'
)
...
...
transformers/data/processors/glue.py
View file @
04c69db3
...
@@ -80,6 +80,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
...
@@ -80,6 +80,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
logger
.
info
(
"Writing example %d"
%
(
ex_index
))
logger
.
info
(
"Writing example %d"
%
(
ex_index
))
if
is_tf_dataset
:
if
is_tf_dataset
:
example
=
processor
.
get_example_from_tensor_dict
(
example
)
example
=
processor
.
get_example_from_tensor_dict
(
example
)
example
=
processor
.
tfds_map
(
example
)
inputs
=
tokenizer
.
encode_plus
(
inputs
=
tokenizer
.
encode_plus
(
example
.
text_a
,
example
.
text_a
,
...
...
transformers/data/processors/utils.py
View file @
04c69db3
...
@@ -107,6 +107,13 @@ class DataProcessor(object):
...
@@ -107,6 +107,13 @@ class DataProcessor(object):
"""Gets the list of labels for this data set."""
"""Gets the list of labels for this data set."""
raise
NotImplementedError
()
raise
NotImplementedError
()
def
tfds_map
(
self
,
example
):
"""Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are.
This method converts examples to the correct format."""
if
len
(
self
.
get_labels
())
>
1
:
example
.
label
=
self
.
get_labels
()[
int
(
example
.
label
)]
return
example
@
classmethod
@
classmethod
def
_read_tsv
(
cls
,
input_file
,
quotechar
=
None
):
def
_read_tsv
(
cls
,
input_file
,
quotechar
=
None
):
"""Reads a tab separated value file."""
"""Reads a tab separated value file."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment