Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6a17b3c5
Unverified
Commit
6a17b3c5
authored
Sep 27, 2019
by
Thomas Wolf
Committed by
GitHub
Sep 27, 2019
Browse files
Merge pull request #1355 from agrinh/master
Fix tensorflow_dataset glue support
parents
04e9a6f5
795b3e76
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
4 deletions
+73
-4
transformers/data/processors/glue.py
transformers/data/processors/glue.py
+64
-4
transformers/data/processors/utils.py
transformers/data/processors/utils.py
+9
-0
No files found.
transformers/data/processors/glue.py
View file @
6a17b3c5
...
@@ -79,10 +79,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
...
@@ -79,10 +79,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
if
ex_index
%
10000
==
0
:
if
ex_index
%
10000
==
0
:
logger
.
info
(
"Writing example %d"
%
(
ex_index
))
logger
.
info
(
"Writing example %d"
%
(
ex_index
))
if
is_tf_dataset
:
if
is_tf_dataset
:
example
=
InputExample
(
example
[
'idx'
].
numpy
(),
example
=
processor
.
get_example_from_tensor_dict
(
example
)
example
[
'sentence1'
].
numpy
().
decode
(
'utf-8'
),
example
[
'sentence2'
].
numpy
().
decode
(
'utf-8'
),
str
(
example
[
'label'
].
numpy
()))
inputs
=
tokenizer
.
encode_plus
(
inputs
=
tokenizer
.
encode_plus
(
example
.
text_a
,
example
.
text_a
,
...
@@ -157,6 +154,13 @@ def glue_convert_examples_to_features(examples, tokenizer,
...
@@ -157,6 +154,13 @@ def glue_convert_examples_to_features(examples, tokenizer,
class
MrpcProcessor
(
DataProcessor
):
class
MrpcProcessor
(
DataProcessor
):
"""Processor for the MRPC data set (GLUE version)."""
"""Processor for the MRPC data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence1'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'sentence2'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
logger
.
info
(
"LOOKING AT {}"
.
format
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)))
logger
.
info
(
"LOOKING AT {}"
.
format
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)))
...
@@ -190,6 +194,13 @@ class MrpcProcessor(DataProcessor):
...
@@ -190,6 +194,13 @@ class MrpcProcessor(DataProcessor):
class
MnliProcessor
(
DataProcessor
):
class
MnliProcessor
(
DataProcessor
):
"""Processor for the MultiNLI data set (GLUE version)."""
"""Processor for the MultiNLI data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'premise'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'hypothesis'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples
(
...
@@ -233,6 +244,13 @@ class MnliMismatchedProcessor(MnliProcessor):
...
@@ -233,6 +244,13 @@ class MnliMismatchedProcessor(MnliProcessor):
class
ColaProcessor
(
DataProcessor
):
class
ColaProcessor
(
DataProcessor
):
"""Processor for the CoLA data set (GLUE version)."""
"""Processor for the CoLA data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence'
].
numpy
().
decode
(
'utf-8'
),
None
,
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples
(
...
@@ -262,6 +280,13 @@ class ColaProcessor(DataProcessor):
...
@@ -262,6 +280,13 @@ class ColaProcessor(DataProcessor):
class
Sst2Processor
(
DataProcessor
):
class
Sst2Processor
(
DataProcessor
):
"""Processor for the SST-2 data set (GLUE version)."""
"""Processor for the SST-2 data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence'
].
numpy
().
decode
(
'utf-8'
),
None
,
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples
(
...
@@ -293,6 +318,13 @@ class Sst2Processor(DataProcessor):
...
@@ -293,6 +318,13 @@ class Sst2Processor(DataProcessor):
class
StsbProcessor
(
DataProcessor
):
class
StsbProcessor
(
DataProcessor
):
"""Processor for the STS-B data set (GLUE version)."""
"""Processor for the STS-B data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence1'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'sentence2'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples
(
...
@@ -325,6 +357,13 @@ class StsbProcessor(DataProcessor):
...
@@ -325,6 +357,13 @@ class StsbProcessor(DataProcessor):
class
QqpProcessor
(
DataProcessor
):
class
QqpProcessor
(
DataProcessor
):
"""Processor for the QQP data set (GLUE version)."""
"""Processor for the QQP data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'question1'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'question2'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples
(
...
@@ -360,6 +399,13 @@ class QqpProcessor(DataProcessor):
...
@@ -360,6 +399,13 @@ class QqpProcessor(DataProcessor):
class
QnliProcessor
(
DataProcessor
):
class
QnliProcessor
(
DataProcessor
):
"""Processor for the QNLI data set (GLUE version)."""
"""Processor for the QNLI data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'question'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'sentence'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples
(
...
@@ -393,6 +439,13 @@ class QnliProcessor(DataProcessor):
...
@@ -393,6 +439,13 @@ class QnliProcessor(DataProcessor):
class
RteProcessor
(
DataProcessor
):
class
RteProcessor
(
DataProcessor
):
"""Processor for the RTE data set (GLUE version)."""
"""Processor for the RTE data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence1'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'sentence2'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples
(
...
@@ -425,6 +478,13 @@ class RteProcessor(DataProcessor):
...
@@ -425,6 +478,13 @@ class RteProcessor(DataProcessor):
class
WnliProcessor
(
DataProcessor
):
class
WnliProcessor
(
DataProcessor
):
"""Processor for the WNLI data set (GLUE version)."""
"""Processor for the WNLI data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence1'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'sentence2'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples
(
...
...
transformers/data/processors/utils.py
View file @
6a17b3c5
...
@@ -86,6 +86,15 @@ class InputFeatures(object):
...
@@ -86,6 +86,15 @@ class InputFeatures(object):
class
DataProcessor
(
object
):
class
DataProcessor
(
object
):
"""Base class for data converters for sequence classification data sets."""
"""Base class for data converters for sequence classification data sets."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""Gets an example from a dict with tensorflow tensors
Args:
tensor_dict: Keys and values should match the corresponding Glue
tensorflow_dataset examples.
"""
raise
NotImplementedError
()
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the train set."""
"""Gets a collection of `InputExample`s for the train set."""
raise
NotImplementedError
()
raise
NotImplementedError
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment