Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f09e5ece
Commit
f09e5ece
authored
Sep 24, 2019
by
LysandreJik
Browse files
[Proposal] GLUE processors included in library
parent
72402d1a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
230 additions
and
204 deletions
+230
-204
examples/run_glue.py
examples/run_glue.py
+2
-3
pytorch_transformers/preprocessing/__init__.py
pytorch_transformers/preprocessing/__init__.py
+56
-0
pytorch_transformers/preprocessing/glue.py
pytorch_transformers/preprocessing/glue.py
+73
-201
pytorch_transformers/preprocessing/utils.py
pytorch_transformers/preprocessing/utils.py
+99
-0
No files found.
examples/run_glue.py
View file @
f09e5ece
...
@@ -46,8 +46,7 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
...
@@ -46,8 +46,7 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
from
pytorch_transformers
import
AdamW
,
WarmupLinearSchedule
from
pytorch_transformers
import
AdamW
,
WarmupLinearSchedule
from
utils_glue
import
(
compute_metrics
,
convert_examples_to_features
,
from
pytorch_transformers.preprocessing
import
(
compute_metrics
,
output_modes
,
processors
,
convert_examples_to_glue_features
)
output_modes
,
processors
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -276,7 +275,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
...
@@ -276,7 +275,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
# HACK(label indices are swapped in RoBERTa pretrained model)
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list
[
1
],
label_list
[
2
]
=
label_list
[
2
],
label_list
[
1
]
label_list
[
1
],
label_list
[
2
]
=
label_list
[
2
],
label_list
[
1
]
examples
=
processor
.
get_dev_examples
(
args
.
data_dir
)
if
evaluate
else
processor
.
get_train_examples
(
args
.
data_dir
)
examples
=
processor
.
get_dev_examples
(
args
.
data_dir
)
if
evaluate
else
processor
.
get_train_examples
(
args
.
data_dir
)
features
=
convert_examples_to_features
(
examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
,
features
=
convert_examples_to_
glue_
features
(
examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
,
pad_on_left
=
bool
(
args
.
model_type
in
[
'xlnet'
]),
# pad on the left for xlnet
pad_on_left
=
bool
(
args
.
model_type
in
[
'xlnet'
]),
# pad on the left for xlnet
pad_token
=
tokenizer
.
convert_tokens_to_ids
([
tokenizer
.
pad_token
])[
0
],
pad_token
=
tokenizer
.
convert_tokens_to_ids
([
tokenizer
.
pad_token
])[
0
],
pad_token_segment_id
=
4
if
args
.
model_type
in
[
'xlnet'
]
else
0
,
pad_token_segment_id
=
4
if
args
.
model_type
in
[
'xlnet'
]
else
0
,
...
...
pytorch_transformers/preprocessing/__init__.py
0 → 100644
View file @
f09e5ece
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Public API of `pytorch_transformers.preprocessing`: GLUE processors,
# the example->feature converter, and the metric helpers.
#
# FIX: the original used implicit relative imports ("from glue import ...",
# "from utils import ..."), which only work on Python 2 and fail on Python 3
# with ModuleNotFoundError. Explicit relative imports work on both.
from .glue import (ColaProcessor, MnliProcessor, MnliMismatchedProcessor,
                   MrpcProcessor, Sst2Processor, StsbProcessor,
                   QqpProcessor, QnliProcessor, RteProcessor, WnliProcessor,
                   convert_examples_to_glue_features)

from .utils import (DataProcessor, simple_accuracy, acc_and_f1,
                    pearson_and_spearman, compute_metrics)

# Maps each GLUE task name (as passed on the command line) to the
# DataProcessor subclass that knows how to read that task's TSV files.
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor,
    "wnli": WnliProcessor,
}

# Every GLUE task is classification except STS-B, which is a regression
# task with continuous similarity scores.
output_modes = {
    "cola": "classification",
    "mnli": "classification",
    "mnli-mm": "classification",
    "mrpc": "classification",
    "sst-2": "classification",
    "sts-b": "regression",
    "qqp": "classification",
    "qnli": "classification",
    "rte": "classification",
    "wnli": "classification",
}
examples/utils_
glue.py
→
pytorch_transformers/preprocessing/
glue.py
View file @
f09e5ece
...
@@ -13,22 +13,84 @@
...
@@ -13,22 +13,84 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""
BERT classification fine-tuning: utilities to work with GLUE task
s """
"""
GLUE processors and helper
s """
from
__future__
import
absolute_import
,
division
,
print_function
from
utils
import
DataProcessor
import
csv
import
logging
import
logging
import
os
import
os
import
sys
from
io
import
open
from
scipy.stats
import
pearsonr
,
spearmanr
from
sklearn.metrics
import
matthews_corrcoef
,
f1_score
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
def convert_examples_to_glue_features(examples, label_list, max_seq_length,
                                      tokenizer, output_mode,
                                      pad_on_left=False,
                                      pad_token=0,
                                      pad_token_segment_id=0,
                                      mask_padding_with_zero=True):
    """ Loads a data file into a list of `InputBatch`s.

    Tokenizes each `InputExample`, pads/truncates to ``max_seq_length`` and
    wraps the result in an `InputFeatures`.

    Args:
        examples: iterable of `InputExample` (each has ``text_a``, an optional
            ``text_b``, a ``guid`` and a ``label``).
        label_list: list of label strings; position in the list becomes the
            integer label id for classification tasks.
        max_seq_length: fixed length every returned sequence is padded or
            truncated to.
        tokenizer: a pytorch_transformers tokenizer providing
            ``encode_plus`` — assumed to return a dict with "input_ids" and
            "token_type_ids" (TODO confirm against the tokenizer API).
        output_mode: "classification" (label mapped through ``label_list``)
            or "regression" (label cast to float); anything else raises
            ``KeyError``.
        pad_on_left: if True, padding is prepended instead of appended
            (XLNet-style).
        pad_token: id used to pad ``input_ids``.
        pad_token_segment_id: segment id used to pad ``token_type_ids``.
        mask_padding_with_zero: if True (default), real tokens get mask 1 and
            padding gets 0; if False the convention is inverted.

    Returns:
        list of `InputFeatures`, one per example.
    """
    # Label string -> integer id, ordered by position in label_list.
    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))

        # Tokenize the (pair of) sequence(s) with special tokens and
        # model-side truncation to max_seq_length.
        inputs = tokenizer.encode_plus(
            example.text_a,
            example.text_b,
            add_special_tokens=True,
            output_token_type=True,
            max_length=max_seq_length,
            truncate_first_sequence=True  # We're truncating the first sequence as a priority
        )
        input_ids, segment_ids = inputs["input_ids"], inputs["token_type_ids"]

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_seq_length - len(input_ids)
        if pad_on_left:
            # XLNet convention: padding goes on the left of the sequence.
            input_ids = ([pad_token] * padding_length) + input_ids
            input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
        else:
            input_ids = input_ids + ([pad_token] * padding_length)
            input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)

        # All three sequences must come out at exactly max_seq_length.
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        if output_mode == "classification":
            label_id = label_map[example.label]
        elif output_mode == "regression":
            label_id = float(example.label)
        else:
            raise KeyError(output_mode)

        # Log the first five examples for eyeballing the conversion.
        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            logger.info("label: %s (id = %d)" % (example.label, label_id))

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label_id))
    return features
class
InputExample
(
object
):
class
InputExample
(
object
):
"""A single training/test example for simple sequence classification."""
"""A single training/test example for simple sequence classification."""
...
@@ -60,34 +122,6 @@ class InputFeatures(object):
...
@@ -60,34 +122,6 @@ class InputFeatures(object):
self
.
label_id
=
label_id
self
.
label_id
=
label_id
class
DataProcessor
(
object
):
"""Base class for data converters for sequence classification data sets."""
def
get_train_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the train set."""
raise
NotImplementedError
()
def
get_dev_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the dev set."""
raise
NotImplementedError
()
def
get_labels
(
self
):
"""Gets the list of labels for this data set."""
raise
NotImplementedError
()
@
classmethod
def
_read_tsv
(
cls
,
input_file
,
quotechar
=
None
):
"""Reads a tab separated value file."""
with
open
(
input_file
,
"r"
,
encoding
=
"utf-8-sig"
)
as
f
:
reader
=
csv
.
reader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
quotechar
)
lines
=
[]
for
line
in
reader
:
if
sys
.
version_info
[
0
]
==
2
:
line
=
list
(
unicode
(
cell
,
'utf-8'
)
for
cell
in
line
)
lines
.
append
(
line
)
return
lines
class
MrpcProcessor
(
DataProcessor
):
class
MrpcProcessor
(
DataProcessor
):
"""Processor for the MRPC data set (GLUE version)."""
"""Processor for the MRPC data set (GLUE version)."""
...
@@ -302,7 +336,7 @@ class QnliProcessor(DataProcessor):
...
@@ -302,7 +336,7 @@ class QnliProcessor(DataProcessor):
def
get_dev_examples
(
self
,
data_dir
):
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev_matched"
)
"dev_matched"
)
def
get_labels
(
self
):
def
get_labels
(
self
):
...
@@ -387,168 +421,6 @@ class WnliProcessor(DataProcessor):
...
@@ -387,168 +421,6 @@ class WnliProcessor(DataProcessor):
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
return
examples
def
convert_examples_to_features
(
examples
,
label_list
,
max_seq_length
,
tokenizer
,
output_mode
,
pad_on_left
=
False
,
pad_token
=
0
,
pad_token_segment_id
=
0
,
mask_padding_with_zero
=
True
):
"""
Loads a data file into a list of `InputBatch`s
"""
label_map
=
{
label
:
i
for
i
,
label
in
enumerate
(
label_list
)}
features
=
[]
for
(
ex_index
,
example
)
in
enumerate
(
examples
):
if
ex_index
%
10000
==
0
:
logger
.
info
(
"Writing example %d of %d"
%
(
ex_index
,
len
(
examples
)))
inputs
=
tokenizer
.
encode_plus
(
example
.
text_a
,
example
.
text_b
,
add_special_tokens
=
True
,
output_token_type
=
True
,
max_length
=
max_seq_length
,
truncate_first_sequence
=
True
# We're truncating the first sequence as a priority
)
input_ids
,
segment_ids
=
inputs
[
"input_ids"
],
inputs
[
"token_type_ids"
]
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask
=
[
1
if
mask_padding_with_zero
else
0
]
*
len
(
input_ids
)
# Zero-pad up to the sequence length.
padding_length
=
max_seq_length
-
len
(
input_ids
)
if
pad_on_left
:
input_ids
=
([
pad_token
]
*
padding_length
)
+
input_ids
input_mask
=
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
+
input_mask
segment_ids
=
([
pad_token_segment_id
]
*
padding_length
)
+
segment_ids
else
:
input_ids
=
input_ids
+
([
pad_token
]
*
padding_length
)
input_mask
=
input_mask
+
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
segment_ids
=
segment_ids
+
([
pad_token_segment_id
]
*
padding_length
)
assert
len
(
input_ids
)
==
max_seq_length
assert
len
(
input_mask
)
==
max_seq_length
assert
len
(
segment_ids
)
==
max_seq_length
if
output_mode
==
"classification"
:
label_id
=
label_map
[
example
.
label
]
elif
output_mode
==
"regression"
:
label_id
=
float
(
example
.
label
)
else
:
raise
KeyError
(
output_mode
)
if
ex_index
<
5
:
logger
.
info
(
"*** Example ***"
)
logger
.
info
(
"guid: %s"
%
(
example
.
guid
))
logger
.
info
(
"input_ids: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
input_ids
]))
logger
.
info
(
"input_mask: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
input_mask
]))
logger
.
info
(
"segment_ids: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
segment_ids
]))
logger
.
info
(
"label: %s (id = %d)"
%
(
example
.
label
,
label_id
))
features
.
append
(
InputFeatures
(
input_ids
=
input_ids
,
input_mask
=
input_mask
,
segment_ids
=
segment_ids
,
label_id
=
label_id
))
return
features
def
_truncate_seq_pair
(
tokens_a
,
tokens_b
,
max_length
):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while
True
:
total_length
=
len
(
tokens_a
)
+
len
(
tokens_b
)
if
total_length
<=
max_length
:
break
if
len
(
tokens_a
)
>
len
(
tokens_b
):
tokens_a
.
pop
()
else
:
tokens_b
.
pop
()
def
simple_accuracy
(
preds
,
labels
):
return
(
preds
==
labels
).
mean
()
def
acc_and_f1
(
preds
,
labels
):
acc
=
simple_accuracy
(
preds
,
labels
)
f1
=
f1_score
(
y_true
=
labels
,
y_pred
=
preds
)
return
{
"acc"
:
acc
,
"f1"
:
f1
,
"acc_and_f1"
:
(
acc
+
f1
)
/
2
,
}
def
pearson_and_spearman
(
preds
,
labels
):
pearson_corr
=
pearsonr
(
preds
,
labels
)[
0
]
spearman_corr
=
spearmanr
(
preds
,
labels
)[
0
]
return
{
"pearson"
:
pearson_corr
,
"spearmanr"
:
spearman_corr
,
"corr"
:
(
pearson_corr
+
spearman_corr
)
/
2
,
}
def
compute_metrics
(
task_name
,
preds
,
labels
):
assert
len
(
preds
)
==
len
(
labels
)
if
task_name
==
"cola"
:
return
{
"mcc"
:
matthews_corrcoef
(
labels
,
preds
)}
elif
task_name
==
"sst-2"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"mrpc"
:
return
acc_and_f1
(
preds
,
labels
)
elif
task_name
==
"sts-b"
:
return
pearson_and_spearman
(
preds
,
labels
)
elif
task_name
==
"qqp"
:
return
acc_and_f1
(
preds
,
labels
)
elif
task_name
==
"mnli"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"mnli-mm"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"qnli"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"rte"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
elif
task_name
==
"wnli"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
else
:
raise
KeyError
(
task_name
)
processors
=
{
"cola"
:
ColaProcessor
,
"mnli"
:
MnliProcessor
,
"mnli-mm"
:
MnliMismatchedProcessor
,
"mrpc"
:
MrpcProcessor
,
"sst-2"
:
Sst2Processor
,
"sts-b"
:
StsbProcessor
,
"qqp"
:
QqpProcessor
,
"qnli"
:
QnliProcessor
,
"rte"
:
RteProcessor
,
"wnli"
:
WnliProcessor
,
}
output_modes
=
{
"cola"
:
"classification"
,
"mnli"
:
"classification"
,
"mnli-mm"
:
"classification"
,
"mrpc"
:
"classification"
,
"sst-2"
:
"classification"
,
"sts-b"
:
"regression"
,
"qqp"
:
"classification"
,
"qnli"
:
"classification"
,
"rte"
:
"classification"
,
"wnli"
:
"classification"
,
}
GLUE_TASKS_NUM_LABELS
=
{
GLUE_TASKS_NUM_LABELS
=
{
"cola"
:
2
,
"cola"
:
2
,
"mnli"
:
3
,
"mnli"
:
3
,
...
@@ -559,4 +431,4 @@ GLUE_TASKS_NUM_LABELS = {
...
@@ -559,4 +431,4 @@ GLUE_TASKS_NUM_LABELS = {
"qnli"
:
2
,
"qnli"
:
2
,
"rte"
:
2
,
"rte"
:
2
,
"wnli"
:
2
,
"wnli"
:
2
,
}
}
\ No newline at end of file
pytorch_transformers/preprocessing/utils.py
0 → 100644
View file @
f09e5ece
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
csv
import
sys
from
scipy.stats
import
pearsonr
,
spearmanr
from
sklearn.metrics
import
matthews_corrcoef
,
f1_score
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets.

    Subclasses implement the three ``get_*`` methods for a concrete task;
    ``_read_tsv`` is a shared helper for loading tab-separated files.
    """

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file into a list of row lists."""
        # utf-8-sig transparently strips a UTF-8 BOM if one is present.
        with open(input_file, "r", encoding="utf-8-sig") as handle:
            rows = csv.reader(handle, delimiter="\t", quotechar=quotechar)
            if sys.version_info[0] == 2:
                # Python 2: csv yields byte strings; decode every cell.
                return [[unicode(cell, 'utf-8') for cell in row] for row in rows]
            return [row for row in rows]
def simple_accuracy(preds, labels):
    """Fraction of positions where ``preds`` equals ``labels``.

    Relies on element-wise ``==`` followed by ``.mean()``, so the inputs
    are presumably numpy arrays (or array-likes supporting both) —
    verify against callers.
    """
    matches = (preds == labels)
    return matches.mean()
def acc_and_f1(preds, labels):
    """Accuracy, binary F1, and their unweighted average.

    Used for the GLUE tasks (MRPC, QQP) that report both metrics.
    """
    accuracy = simple_accuracy(preds, labels)
    f1 = f1_score(y_true=labels, y_pred=preds)
    combined = (accuracy + f1) / 2
    return {
        "acc": accuracy,
        "f1": f1,
        "acc_and_f1": combined,
    }
def pearson_and_spearman(preds, labels):
    """Pearson and Spearman correlations plus their unweighted average.

    Used for STS-B, the regression task of GLUE. Only the correlation
    coefficient (index 0) is kept from scipy's (statistic, p-value) pairs.
    """
    pearson = pearsonr(preds, labels)[0]
    spearman = spearmanr(preds, labels)[0]
    return {
        "pearson": pearson,
        "spearmanr": spearman,
        "corr": (pearson + spearman) / 2,
    }
def compute_metrics(task_name, preds, labels):
    """Compute the evaluation metric(s) for a GLUE task.

    Args:
        task_name: lowercase GLUE task key ("cola", "mnli", "mrpc", ...).
        preds: model predictions (same length as ``labels``).
        labels: gold labels.

    Returns:
        dict mapping metric name(s) to value(s) for the task.

    Raises:
        KeyError: if ``task_name`` is not a known GLUE task.

    FIX: the original was a ten-branch if/elif ladder where six branches
    duplicated the identical ``{"acc": simple_accuracy(...)}`` expression;
    branches are grouped by the metric they share instead. Behavior
    (including ``KeyError(task_name)`` for unknown tasks) is unchanged.
    """
    # NOTE(review): kept as assert to preserve the AssertionError contract;
    # an explicit ValueError would be stripped-proof under ``python -O``.
    assert len(preds) == len(labels)
    if task_name == "cola":
        return {"mcc": matthews_corrcoef(labels, preds)}
    if task_name in ("mrpc", "qqp"):
        return acc_and_f1(preds, labels)
    if task_name == "sts-b":
        return pearson_and_spearman(preds, labels)
    if task_name in ("sst-2", "mnli", "mnli-mm", "qnli", "rte", "wnli"):
        return {"acc": simple_accuracy(preds, labels)}
    raise KeyError(task_name)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment