Commit 8a618e0a authored by thomwolf's avatar thomwolf
Browse files

clean up __init__

parent 3b7fb48c
...@@ -16,7 +16,21 @@ import logging ...@@ -16,7 +16,21 @@ import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
# Tokenizer # Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
is_tf_available, is_torch_available)
from .data import (is_sklearn_available,
InputExample, InputFeatures, DataProcessor,
glue_output_modes, glue_convert_examples_to_features,
glue_processors, glue_tasks_num_labels)
if is_sklearn_available():
from .data import glue_compute_metrics
# Tokenizers
from .tokenization_utils import (PreTrainedTokenizer) from .tokenization_utils import (PreTrainedTokenizer)
from .tokenization_auto import AutoTokenizer from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
...@@ -41,13 +55,7 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH ...@@ -41,13 +55,7 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
# Modeling # Modeling
try: if is_torch_available():
import torch
_torch_available = True # pylint: disable=invalid-name
except ImportError:
_torch_available = False # pylint: disable=invalid-name
if _torch_available:
logger.info("PyTorch version {} available.".format(torch.__version__)) logger.info("PyTorch version {} available.".format(torch.__version__))
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D) from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
...@@ -87,14 +95,7 @@ if _torch_available: ...@@ -87,14 +95,7 @@ if _torch_available:
# TensorFlow # TensorFlow
try: if is_tf_available():
import tensorflow as tf
assert int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
if _tf_available:
logger.info("TensorFlow version {} available.".format(tf.__version__)) logger.info("TensorFlow version {} available.".format(tf.__version__))
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
...@@ -151,7 +152,8 @@ if _tf_available: ...@@ -151,7 +152,8 @@ if _tf_available:
load_distilbert_pt_weights_in_tf2, load_distilbert_pt_weights_in_tf2,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
if _tf_available and _torch_available: # TF 2.0 <=> PyTorch conversion utilities
if is_tf_available() and is_torch_available():
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name, from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
load_pytorch_checkpoint_in_tf2_model, load_pytorch_checkpoint_in_tf2_model,
load_pytorch_weights_in_tf2_model, load_pytorch_weights_in_tf2_model,
...@@ -159,17 +161,3 @@ if _tf_available and _torch_available: ...@@ -159,17 +161,3 @@ if _tf_available and _torch_available:
load_tf2_checkpoint_in_pytorch_model, load_tf2_checkpoint_in_pytorch_model,
load_tf2_weights_in_pytorch_model, load_tf2_weights_in_pytorch_model,
load_tf2_model_in_pytorch_model) load_tf2_model_in_pytorch_model)
# Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
is_tf_available, is_torch_available)
from .data import (is_sklearn_available,
InputExample, InputFeatures, DataProcessor,
glue_output_modes, glue_convert_examples_to_features,
glue_processors, glue_tasks_num_labels)
if is_sklearn_available():
from .data import glue_compute_metrics
...@@ -23,7 +23,7 @@ import six ...@@ -23,7 +23,7 @@ import six
import copy import copy
from io import open from io import open
from .file_utils import cached_path, is_tf_available from .file_utils import cached_path, is_tf_available, is_torch_available
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
...@@ -690,7 +690,15 @@ class PreTrainedTokenizer(object): ...@@ -690,7 +690,15 @@ class PreTrainedTokenizer(object):
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
raise NotImplementedError raise NotImplementedError
def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs): def encode(self,
text,
text_pair=None,
add_special_tokens=False,
max_length=None,
stride=0,
truncate_first_sequence=True,
return_tensors=None,
**kwargs):
""" """
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
...@@ -705,9 +713,24 @@ class PreTrainedTokenizer(object): ...@@ -705,9 +713,24 @@ class PreTrainedTokenizer(object):
`convert_tokens_to_ids` method) `convert_tokens_to_ids` method)
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
to their model. to their model.
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
If there are overflowing tokens, those will be added to the returned dictionary
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defined the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method **kwargs: passed to the `self.tokenize()` method
""" """
encoded_inputs = self.encode_plus(text, text_pair=text_pair, add_special_tokens=add_special_tokens, **kwargs) encoded_inputs = self.encode_plus(text,
text_pair=text_pair,
max_length=max_length,
add_special_tokens=add_special_tokens,
stride=stride,
truncate_first_sequence=truncate_first_sequence,
return_tensors=return_tensors,
**kwargs)
return encoded_inputs["input_ids"] return encoded_inputs["input_ids"]
...@@ -718,10 +741,11 @@ class PreTrainedTokenizer(object): ...@@ -718,10 +741,11 @@ class PreTrainedTokenizer(object):
max_length=None, max_length=None,
stride=0, stride=0,
truncate_first_sequence=True, truncate_first_sequence=True,
return_tensors=None,
**kwargs): **kwargs):
""" """
Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
Args: Args:
text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
...@@ -738,6 +762,8 @@ class PreTrainedTokenizer(object): ...@@ -738,6 +762,8 @@ class PreTrainedTokenizer(object):
from the main sequence returned. The value of this argument defined the number of additional tokens. from the main sequence returned. The value of this argument defined the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated. will be truncated.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method **kwargs: passed to the `self.tokenize()` method
""" """
...@@ -759,10 +785,12 @@ class PreTrainedTokenizer(object): ...@@ -759,10 +785,12 @@ class PreTrainedTokenizer(object):
max_length=max_length, max_length=max_length,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
stride=stride, stride=stride,
truncate_first_sequence=truncate_first_sequence) truncate_first_sequence=truncate_first_sequence,
return_tensors=return_tensors)
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0, truncate_first_sequence=True): def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
truncate_first_sequence=True, return_tensors=None):
""" """
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates It adds special tokens, truncates
...@@ -782,6 +810,8 @@ class PreTrainedTokenizer(object): ...@@ -782,6 +810,8 @@ class PreTrainedTokenizer(object):
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided, truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
alongside a specified `max_length`, will truncate the first sequence if the total size is superior alongside a specified `max_length`, will truncate the first sequence if the total size is superior
than the specified `max_length`. If set to `False`, will truncate the second sequence instead. than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
Return: Return:
a dictionary containing the `input_ids` as well as the `overflowing_tokens` if a `max_length` was given. a dictionary containing the `input_ids` as well as the `overflowing_tokens` if a `max_length` was given.
...@@ -816,6 +846,11 @@ class PreTrainedTokenizer(object): ...@@ -816,6 +846,11 @@ class PreTrainedTokenizer(object):
sequence = ids + pair_ids if pair else ids sequence = ids + pair_ids if pair else ids
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else []) token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
if return_tensors == 'tf' and is_tf_available():
sequence = tf.constant(sequence)
token_type_ids = tf.constant(token_type_ids)
elif return_tensors = 'pt' and is
encoded_inputs["input_ids"] = sequence encoded_inputs["input_ids"] = sequence
encoded_inputs["token_type_ids"] = token_type_ids encoded_inputs["token_type_ids"] = token_type_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment