Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e4022d96
Unverified
Commit
e4022d96
authored
Sep 24, 2019
by
Thomas Wolf
Committed by
GitHub
Sep 24, 2019
Browse files
Merge pull request #1325 from huggingface/glue-included
[Proposal] GLUE processors included in library
parents
a6981076
789ea720
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
252 additions
and
209 deletions
+252
-209
.gitignore
.gitignore
+1
-1
examples/run_glue.py
examples/run_glue.py
+4
-2
pytorch_transformers/__init__.py
pytorch_transformers/__init__.py
+7
-0
pytorch_transformers/data/__init__.py
pytorch_transformers/data/__init__.py
+6
-0
pytorch_transformers/data/metrics/__init__.py
pytorch_transformers/data/metrics/__init__.py
+83
-0
pytorch_transformers/data/processors/__init__.py
pytorch_transformers/data/processors/__init__.py
+3
-0
pytorch_transformers/data/processors/glue.py
pytorch_transformers/data/processors/glue.py
+73
-206
pytorch_transformers/data/processors/utils.py
pytorch_transformers/data/processors/utils.py
+75
-0
No files found.
.gitignore
View file @
e4022d96
...
...
@@ -130,5 +130,5 @@ runs
examples/runs
# data
data
/
data
serialization_dir
\ No newline at end of file
examples/run_glue.py
View file @
e4022d96
...
...
@@ -46,8 +46,10 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
from
pytorch_transformers
import
AdamW
,
WarmupLinearSchedule
from
utils_glue
import
(
compute_metrics
,
convert_examples_to_features
,
output_modes
,
processors
)
from
pytorch_transformers
import
glue_compute_metrics
as
compute_metrics
from
pytorch_transformers
import
glue_output_modes
as
output_modes
from
pytorch_transformers
import
glue_processors
as
processors
from
pytorch_transformers
import
glue_convert_examples_to_features
as
convert_examples_to_features
logger
=
logging
.
getLogger
(
__name__
)
...
...
pytorch_transformers/__init__.py
View file @
e4022d96
...
...
@@ -73,3 +73,10 @@ from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, Wa
from
.file_utils
import
(
PYTORCH_TRANSFORMERS_CACHE
,
PYTORCH_PRETRAINED_BERT_CACHE
,
cached_path
,
add_start_docstrings
,
add_end_docstrings
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
CONFIG_NAME
)
from
.data
import
(
is_sklearn_available
,
InputExample
,
InputFeatures
,
DataProcessor
,
glue_output_modes
,
glue_convert_examples_to_features
,
glue_processors
,
glue_tasks_num_labels
)
if
is_sklearn_available
():
from
.data
import
glue_compute_metrics
pytorch_transformers/data/__init__.py
0 → 100644
View file @
e4022d96
from
.processors
import
InputExample
,
InputFeatures
,
DataProcessor
from
.processors
import
glue_output_modes
,
glue_processors
,
glue_tasks_num_labels
,
glue_convert_examples_to_features
from
.metrics
import
is_sklearn_available
if
is_sklearn_available
():
from
.metrics
import
glue_compute_metrics
pytorch_transformers/data/metrics/__init__.py
0 → 100644
View file @
e4022d96
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
csv
import
sys
import
logging
logger
=
logging
.
getLogger
(
__name__
)
try
:
from
scipy.stats
import
pearsonr
,
spearmanr
from
sklearn.metrics
import
matthews_corrcoef
,
f1_score
_has_sklearn
=
True
except
(
AttributeError
,
ImportError
)
as
e
:
logger
.
warning
(
"To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html"
)
_has_sklearn
=
False
def
is_sklearn_available
():
return
_has_sklearn
if
_has_sklearn
:
def
simple_accuracy
(
preds
,
labels
):
return
(
preds
==
labels
).
mean
()
def
acc_and_f1
(
preds
,
labels
):
acc
=
simple_accuracy
(
preds
,
labels
)
f1
=
f1_score
(
y_true
=
labels
,
y_pred
=
preds
)
return
{
"acc"
:
acc
,
"f1"
:
f1
,
"acc_and_f1"
:
(
acc
+
f1
)
/
2
,
}
def
pearson_and_spearman
(
preds
,
labels
):
pearson_corr
=
pearsonr
(
preds
,
labels
)[
0
]
spearman_corr
=
spearmanr
(
preds
,
labels
)[
0
]
return
{
"pearson"
:
pearson_corr
,
"spearmanr"
:
spearman_corr
,
"corr"
:
(
pearson_corr
+
spearman_corr
)
/
2
,
}
def
glue_compute_metrics
(
task_name
,
preds
,
labels
):
assert
len
(
preds
)
==
len
(
labels
)
if
task_name
==
"cola"
:
return
{
"mcc"
:
matthews_corrcoef
(
labels
,
preds
)}
elif
task_name
==
"sst-2"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"mrpc"
:
return
acc_and_f1
(
preds
,
labels
)
elif
task_name
==
"sts-b"
:
return
pearson_and_spearman
(
preds
,
labels
)
elif
task_name
==
"qqp"
:
return
acc_and_f1
(
preds
,
labels
)
elif
task_name
==
"mnli"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"mnli-mm"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"qnli"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"rte"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"wnli"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
else
:
raise
KeyError
(
task_name
)
pytorch_transformers/data/processors/__init__.py
0 → 100644
View file @
e4022d96
from
.utils
import
InputExample
,
InputFeatures
,
DataProcessor
from
.glue
import
glue_output_modes
,
glue_processors
,
glue_tasks_num_labels
,
glue_convert_examples_to_features
examples/utils_
glue.py
→
pytorch_transformers/data/processors/
glue.py
View file @
e4022d96
...
...
@@ -13,79 +13,81 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
BERT classification fine-tuning: utilities to work with GLUE task
s """
"""
GLUE processors and helper
s """
from
__future__
import
absolute_import
,
division
,
print_function
import
csv
import
logging
import
os
import
sys
from
io
import
open
from
scipy.stats
import
pearsonr
,
spearmanr
from
sklearn.metrics
import
matthews_corrcoef
,
f1_score
from
.utils
import
DataProcessor
,
InputExample
,
InputFeatures
logger
=
logging
.
getLogger
(
__name__
)
class
InputExample
(
object
):
"""A single training/test example for simple sequence classification."""
def
__init__
(
self
,
guid
,
text_a
,
text_b
=
None
,
label
=
None
):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
def
glue_convert_examples_to_features
(
examples
,
label_list
,
max_seq_length
,
tokenizer
,
output_mode
,
pad_on_left
=
False
,
pad_token
=
0
,
pad_token_segment_id
=
0
,
mask_padding_with_zero
=
True
):
"""
Loads a data file into a list of `InputBatch`s
"""
self
.
guid
=
guid
self
.
text_a
=
text_a
self
.
text_b
=
text_b
self
.
label
=
label
label_map
=
{
label
:
i
for
i
,
label
in
enumerate
(
label_list
)}
class
InputFeatures
(
object
):
"""A single set of features of data."""
features
=
[]
for
(
ex_index
,
example
)
in
enumerate
(
examples
):
if
ex_index
%
10000
==
0
:
logger
.
info
(
"Writing example %d of %d"
%
(
ex_index
,
len
(
examples
)))
def
__init__
(
self
,
input_ids
,
input_mask
,
segment_ids
,
label_id
):
self
.
input_ids
=
input_ids
self
.
input_mask
=
input_mask
self
.
segment_ids
=
segment_ids
self
.
label_id
=
label_id
inputs
=
tokenizer
.
encode_plus
(
example
.
text_a
,
example
.
text_b
,
add_special_tokens
=
True
,
max_length
=
max_seq_length
,
truncate_first_sequence
=
True
# We're truncating the first sequence as a priority
)
input_ids
,
segment_ids
=
inputs
[
"input_ids"
],
inputs
[
"token_type_ids"
]
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask
=
[
1
if
mask_padding_with_zero
else
0
]
*
len
(
input_ids
)
class
DataProcessor
(
object
):
"""Base class for data converters for sequence classification data sets."""
# Zero-pad up to the sequence length.
padding_length
=
max_seq_length
-
len
(
input_ids
)
if
pad_on_left
:
input_ids
=
([
pad_token
]
*
padding_length
)
+
input_ids
input_mask
=
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
+
input_mask
segment_ids
=
([
pad_token_segment_id
]
*
padding_length
)
+
segment_ids
else
:
input_ids
=
input_ids
+
([
pad_token
]
*
padding_length
)
input_mask
=
input_mask
+
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
segment_ids
=
segment_ids
+
([
pad_token_segment_id
]
*
padding_length
)
def
get_train_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the train set."""
raise
NotImplementedError
()
assert
len
(
input_ids
)
==
max_seq_length
assert
len
(
input_mask
)
==
max_seq_length
assert
len
(
segment_ids
)
==
max_seq_length
def
get_dev_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the dev set."""
raise
NotImplementedError
()
if
output_mode
==
"classification"
:
label_id
=
label_map
[
example
.
label
]
elif
output_mode
==
"regression"
:
label_id
=
float
(
example
.
label
)
else
:
raise
KeyError
(
output_mode
)
def
get_labels
(
self
):
"""Gets the list of labels for this data set."""
raise
NotImplementedError
()
@
classmethod
def
_read_tsv
(
cls
,
input_file
,
quotechar
=
None
):
"""Reads a tab separated value file."""
with
open
(
input_file
,
"r"
,
encoding
=
"utf-8-sig"
)
as
f
:
reader
=
csv
.
reader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
quotechar
)
lines
=
[]
for
line
in
reader
:
if
sys
.
version_info
[
0
]
==
2
:
line
=
list
(
unicode
(
cell
,
'utf-8'
)
for
cell
in
line
)
lines
.
append
(
line
)
return
lines
if
ex_index
<
5
:
logger
.
info
(
"*** Example ***"
)
logger
.
info
(
"guid: %s"
%
(
example
.
guid
))
logger
.
info
(
"input_ids: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
input_ids
]))
logger
.
info
(
"input_mask: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
input_mask
]))
logger
.
info
(
"segment_ids: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
segment_ids
]))
logger
.
info
(
"label: %s (id = %d)"
%
(
example
.
label
,
label_id
))
features
.
append
(
InputFeatures
(
input_ids
=
input_ids
,
input_mask
=
input_mask
,
segment_ids
=
segment_ids
,
label_id
=
label_id
))
return
features
class
MrpcProcessor
(
DataProcessor
):
...
...
@@ -387,142 +389,19 @@ class WnliProcessor(DataProcessor):
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
glue_tasks_num_labels
=
{
"cola"
:
2
,
"mnli"
:
3
,
"mrpc"
:
2
,
"sst-2"
:
2
,
"sts-b"
:
1
,
"qqp"
:
2
,
"qnli"
:
2
,
"rte"
:
2
,
"wnli"
:
2
,
}
def
convert_examples_to_features
(
examples
,
label_list
,
max_seq_length
,
tokenizer
,
output_mode
,
pad_on_left
=
False
,
pad_token
=
0
,
pad_token_segment_id
=
0
,
mask_padding_with_zero
=
True
):
"""
Loads a data file into a list of `InputBatch`s
"""
label_map
=
{
label
:
i
for
i
,
label
in
enumerate
(
label_list
)}
features
=
[]
for
(
ex_index
,
example
)
in
enumerate
(
examples
):
if
ex_index
%
10000
==
0
:
logger
.
info
(
"Writing example %d of %d"
%
(
ex_index
,
len
(
examples
)))
inputs
=
tokenizer
.
encode_plus
(
example
.
text_a
,
example
.
text_b
,
add_special_tokens
=
True
,
max_length
=
max_seq_length
,
truncate_first_sequence
=
True
# We're truncating the first sequence as a priority
)
input_ids
,
segment_ids
=
inputs
[
"input_ids"
],
inputs
[
"token_type_ids"
]
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask
=
[
1
if
mask_padding_with_zero
else
0
]
*
len
(
input_ids
)
# Zero-pad up to the sequence length.
padding_length
=
max_seq_length
-
len
(
input_ids
)
if
pad_on_left
:
input_ids
=
([
pad_token
]
*
padding_length
)
+
input_ids
input_mask
=
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
+
input_mask
segment_ids
=
([
pad_token_segment_id
]
*
padding_length
)
+
segment_ids
else
:
input_ids
=
input_ids
+
([
pad_token
]
*
padding_length
)
input_mask
=
input_mask
+
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
segment_ids
=
segment_ids
+
([
pad_token_segment_id
]
*
padding_length
)
assert
len
(
input_ids
)
==
max_seq_length
assert
len
(
input_mask
)
==
max_seq_length
assert
len
(
segment_ids
)
==
max_seq_length
if
output_mode
==
"classification"
:
label_id
=
label_map
[
example
.
label
]
elif
output_mode
==
"regression"
:
label_id
=
float
(
example
.
label
)
else
:
raise
KeyError
(
output_mode
)
if
ex_index
<
5
:
logger
.
info
(
"*** Example ***"
)
logger
.
info
(
"guid: %s"
%
(
example
.
guid
))
logger
.
info
(
"input_ids: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
input_ids
]))
logger
.
info
(
"input_mask: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
input_mask
]))
logger
.
info
(
"segment_ids: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
segment_ids
]))
logger
.
info
(
"label: %s (id = %d)"
%
(
example
.
label
,
label_id
))
features
.
append
(
InputFeatures
(
input_ids
=
input_ids
,
input_mask
=
input_mask
,
segment_ids
=
segment_ids
,
label_id
=
label_id
))
return
features
def
_truncate_seq_pair
(
tokens_a
,
tokens_b
,
max_length
):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while
True
:
total_length
=
len
(
tokens_a
)
+
len
(
tokens_b
)
if
total_length
<=
max_length
:
break
if
len
(
tokens_a
)
>
len
(
tokens_b
):
tokens_a
.
pop
()
else
:
tokens_b
.
pop
()
def
simple_accuracy
(
preds
,
labels
):
return
(
preds
==
labels
).
mean
()
def
acc_and_f1
(
preds
,
labels
):
acc
=
simple_accuracy
(
preds
,
labels
)
f1
=
f1_score
(
y_true
=
labels
,
y_pred
=
preds
)
return
{
"acc"
:
acc
,
"f1"
:
f1
,
"acc_and_f1"
:
(
acc
+
f1
)
/
2
,
}
def
pearson_and_spearman
(
preds
,
labels
):
pearson_corr
=
pearsonr
(
preds
,
labels
)[
0
]
spearman_corr
=
spearmanr
(
preds
,
labels
)[
0
]
return
{
"pearson"
:
pearson_corr
,
"spearmanr"
:
spearman_corr
,
"corr"
:
(
pearson_corr
+
spearman_corr
)
/
2
,
}
def
compute_metrics
(
task_name
,
preds
,
labels
):
assert
len
(
preds
)
==
len
(
labels
)
if
task_name
==
"cola"
:
return
{
"mcc"
:
matthews_corrcoef
(
labels
,
preds
)}
elif
task_name
==
"sst-2"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"mrpc"
:
return
acc_and_f1
(
preds
,
labels
)
elif
task_name
==
"sts-b"
:
return
pearson_and_spearman
(
preds
,
labels
)
elif
task_name
==
"qqp"
:
return
acc_and_f1
(
preds
,
labels
)
elif
task_name
==
"mnli"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"mnli-mm"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"qnli"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"rte"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"wnli"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
else
:
raise
KeyError
(
task_name
)
processors
=
{
glue_processors
=
{
"cola"
:
ColaProcessor
,
"mnli"
:
MnliProcessor
,
"mnli-mm"
:
MnliMismatchedProcessor
,
...
...
@@ -535,7 +414,7 @@ processors = {
"wnli"
:
WnliProcessor
,
}
output_modes
=
{
glue_
output_modes
=
{
"cola"
:
"classification"
,
"mnli"
:
"classification"
,
"mnli-mm"
:
"classification"
,
...
...
@@ -547,15 +426,3 @@ output_modes = {
"rte"
:
"classification"
,
"wnli"
:
"classification"
,
}
GLUE_TASKS_NUM_LABELS
=
{
"cola"
:
2
,
"mnli"
:
3
,
"mrpc"
:
2
,
"sst-2"
:
2
,
"sts-b"
:
1
,
"qqp"
:
2
,
"qnli"
:
2
,
"rte"
:
2
,
"wnli"
:
2
,
}
pytorch_transformers/data/processors/utils.py
0 → 100644
View file @
e4022d96
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
csv
import
sys
class
InputExample
(
object
):
"""A single training/test example for simple sequence classification."""
def
__init__
(
self
,
guid
,
text_a
,
text_b
=
None
,
label
=
None
):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self
.
guid
=
guid
self
.
text_a
=
text_a
self
.
text_b
=
text_b
self
.
label
=
label
class
InputFeatures
(
object
):
"""A single set of features of data."""
def
__init__
(
self
,
input_ids
,
input_mask
,
segment_ids
,
label_id
):
self
.
input_ids
=
input_ids
self
.
input_mask
=
input_mask
self
.
segment_ids
=
segment_ids
self
.
label_id
=
label_id
class
DataProcessor
(
object
):
"""Base class for data converters for sequence classification data sets."""
def
get_train_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the train set."""
raise
NotImplementedError
()
def
get_dev_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the dev set."""
raise
NotImplementedError
()
def
get_labels
(
self
):
"""Gets the list of labels for this data set."""
raise
NotImplementedError
()
@
classmethod
def
_read_tsv
(
cls
,
input_file
,
quotechar
=
None
):
"""Reads a tab separated value file."""
with
open
(
input_file
,
"r"
,
encoding
=
"utf-8-sig"
)
as
f
:
reader
=
csv
.
reader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
quotechar
)
lines
=
[]
for
line
in
reader
:
if
sys
.
version_info
[
0
]
==
2
:
line
=
list
(
unicode
(
cell
,
'utf-8'
)
for
cell
in
line
)
lines
.
append
(
line
)
return
lines
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment