Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6a17b3c5
Unverified
Commit
6a17b3c5
authored
Sep 27, 2019
by
Thomas Wolf
Committed by
GitHub
Sep 27, 2019
Browse files
Merge pull request #1355 from agrinh/master
Fix tensorflow_dataset glue support
parents
04e9a6f5
795b3e76
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
4 deletions
+73
-4
transformers/data/processors/glue.py
transformers/data/processors/glue.py
+64
-4
transformers/data/processors/utils.py
transformers/data/processors/utils.py
+9
-0
No files found.
transformers/data/processors/glue.py
View file @
6a17b3c5
...
...
@@ -79,10 +79,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
if
ex_index
%
10000
==
0
:
logger
.
info
(
"Writing example %d"
%
(
ex_index
))
if
is_tf_dataset
:
example
=
InputExample
(
example
[
'idx'
].
numpy
(),
example
[
'sentence1'
].
numpy
().
decode
(
'utf-8'
),
example
[
'sentence2'
].
numpy
().
decode
(
'utf-8'
),
str
(
example
[
'label'
].
numpy
()))
example
=
processor
.
get_example_from_tensor_dict
(
example
)
inputs
=
tokenizer
.
encode_plus
(
example
.
text_a
,
...
...
@@ -157,6 +154,13 @@ def glue_convert_examples_to_features(examples, tokenizer,
class
MrpcProcessor
(
DataProcessor
):
"""Processor for the MRPC data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence1'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'sentence2'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
logger
.
info
(
"LOOKING AT {}"
.
format
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)))
...
...
@@ -190,6 +194,13 @@ class MrpcProcessor(DataProcessor):
class
MnliProcessor
(
DataProcessor
):
"""Processor for the MultiNLI data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'premise'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'hypothesis'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
...
...
@@ -233,6 +244,13 @@ class MnliMismatchedProcessor(MnliProcessor):
class
ColaProcessor
(
DataProcessor
):
"""Processor for the CoLA data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence'
].
numpy
().
decode
(
'utf-8'
),
None
,
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
...
...
@@ -262,6 +280,13 @@ class ColaProcessor(DataProcessor):
class
Sst2Processor
(
DataProcessor
):
"""Processor for the SST-2 data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence'
].
numpy
().
decode
(
'utf-8'
),
None
,
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
...
...
@@ -293,6 +318,13 @@ class Sst2Processor(DataProcessor):
class
StsbProcessor
(
DataProcessor
):
"""Processor for the STS-B data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence1'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'sentence2'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
...
...
@@ -325,6 +357,13 @@ class StsbProcessor(DataProcessor):
class
QqpProcessor
(
DataProcessor
):
"""Processor for the QQP data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'question1'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'question2'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
...
...
@@ -360,6 +399,13 @@ class QqpProcessor(DataProcessor):
class
QnliProcessor
(
DataProcessor
):
"""Processor for the QNLI data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'question'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'sentence'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
...
...
@@ -393,6 +439,13 @@ class QnliProcessor(DataProcessor):
class
RteProcessor
(
DataProcessor
):
"""Processor for the RTE data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence1'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'sentence2'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
...
...
@@ -425,6 +478,13 @@ class RteProcessor(DataProcessor):
class
WnliProcessor
(
DataProcessor
):
"""Processor for the WNLI data set (GLUE version)."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""See base class."""
return
InputExample
(
tensor_dict
[
'idx'
].
numpy
(),
tensor_dict
[
'sentence1'
].
numpy
().
decode
(
'utf-8'
),
tensor_dict
[
'sentence2'
].
numpy
().
decode
(
'utf-8'
),
str
(
tensor_dict
[
'label'
].
numpy
()))
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
...
...
transformers/data/processors/utils.py
View file @
6a17b3c5
...
...
@@ -86,6 +86,15 @@ class InputFeatures(object):
class
DataProcessor
(
object
):
"""Base class for data converters for sequence classification data sets."""
def
get_example_from_tensor_dict
(
self
,
tensor_dict
):
"""Gets an example from a dict with tensorflow tensors
Args:
tensor_dict: Keys and values should match the corresponding Glue
tensorflow_dataset examples.
"""
raise
NotImplementedError
()
def
get_train_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the train set."""
raise
NotImplementedError
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment