"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7a0cf0ec936864c693b47aefc63472f09bccfe06"
Commit 795b3e76 authored by Agrin Hilmkil
Browse files

Add docstring for processor method

parent e31a4728
...@@ -155,6 +155,7 @@ class MrpcProcessor(DataProcessor): ...@@ -155,6 +155,7 @@ class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version).""" """Processor for the MRPC data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'), tensor_dict['sentence2'].numpy().decode('utf-8'),
...@@ -194,6 +195,7 @@ class MnliProcessor(DataProcessor): ...@@ -194,6 +195,7 @@ class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version).""" """Processor for the MultiNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['premise'].numpy().decode('utf-8'), tensor_dict['premise'].numpy().decode('utf-8'),
tensor_dict['hypothesis'].numpy().decode('utf-8'), tensor_dict['hypothesis'].numpy().decode('utf-8'),
...@@ -243,6 +245,7 @@ class ColaProcessor(DataProcessor): ...@@ -243,6 +245,7 @@ class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version).""" """Processor for the CoLA data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence'].numpy().decode('utf-8'), tensor_dict['sentence'].numpy().decode('utf-8'),
None, None,
...@@ -278,6 +281,7 @@ class Sst2Processor(DataProcessor): ...@@ -278,6 +281,7 @@ class Sst2Processor(DataProcessor):
"""Processor for the SST-2 data set (GLUE version).""" """Processor for the SST-2 data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence'].numpy().decode('utf-8'), tensor_dict['sentence'].numpy().decode('utf-8'),
None, None,
...@@ -315,6 +319,7 @@ class StsbProcessor(DataProcessor): ...@@ -315,6 +319,7 @@ class StsbProcessor(DataProcessor):
"""Processor for the STS-B data set (GLUE version).""" """Processor for the STS-B data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'), tensor_dict['sentence2'].numpy().decode('utf-8'),
...@@ -353,6 +358,7 @@ class QqpProcessor(DataProcessor): ...@@ -353,6 +358,7 @@ class QqpProcessor(DataProcessor):
"""Processor for the QQP data set (GLUE version).""" """Processor for the QQP data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['question1'].numpy().decode('utf-8'), tensor_dict['question1'].numpy().decode('utf-8'),
tensor_dict['question2'].numpy().decode('utf-8'), tensor_dict['question2'].numpy().decode('utf-8'),
...@@ -394,6 +400,7 @@ class QnliProcessor(DataProcessor): ...@@ -394,6 +400,7 @@ class QnliProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version).""" """Processor for the QNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['question'].numpy().decode('utf-8'), tensor_dict['question'].numpy().decode('utf-8'),
tensor_dict['sentence'].numpy().decode('utf-8'), tensor_dict['sentence'].numpy().decode('utf-8'),
...@@ -433,6 +440,7 @@ class RteProcessor(DataProcessor): ...@@ -433,6 +440,7 @@ class RteProcessor(DataProcessor):
"""Processor for the RTE data set (GLUE version).""" """Processor for the RTE data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'), tensor_dict['sentence2'].numpy().decode('utf-8'),
...@@ -471,6 +479,7 @@ class WnliProcessor(DataProcessor): ...@@ -471,6 +479,7 @@ class WnliProcessor(DataProcessor):
"""Processor for the WNLI data set (GLUE version).""" """Processor for the WNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'), tensor_dict['sentence2'].numpy().decode('utf-8'),
......
...@@ -86,6 +86,15 @@ class InputFeatures(object): ...@@ -86,6 +86,15 @@ class InputFeatures(object):
class DataProcessor(object): class DataProcessor(object):
"""Base class for data converters for sequence classification data sets.""" """Base class for data converters for sequence classification data sets."""
def get_example_from_tensor_dict(self, tensor_dict):
    """Gets an example from a dict with TensorFlow tensors.

    Base-class stub: it always raises, so each concrete processor must
    provide its own implementation.

    Args:
        tensor_dict: Keys and values should match the corresponding Glue
            tensorflow_dataset examples.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set.""" """Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError() raise NotImplementedError()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment