Commit c4acc3a8 authored by thomwolf

let encode accept tensor inputs

parent e8e956db
@@ -163,10 +163,5 @@ if _tf_available and _torch_available:
 # Files and general utilities
 from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
                          cached_path, add_start_docstrings, add_end_docstrings,
-                         WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
-
-def is_torch_available():
-    return _torch_available
-
-def is_tf_available():
-    return _tf_available
+                         WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
+                         is_tf_available, is_torch_available)
\ No newline at end of file
@@ -23,6 +23,20 @@ from botocore.exceptions import ClientError
 import requests
 from tqdm import tqdm
 
+try:
+    import tensorflow as tf
+    assert int(tf.__version__[0]) >= 2
+    _tf_available = True  # pylint: disable=invalid-name
+except (ImportError, AssertionError):
+    _tf_available = False  # pylint: disable=invalid-name
+
+try:
+    import torch
+    _torch_available = True  # pylint: disable=invalid-name
+except ImportError:
+    _torch_available = False  # pylint: disable=invalid-name
+
 try:
     from torch.hub import _get_torch_home
     torch_cache_home = _get_torch_home()
@@ -55,6 +69,12 @@ CONFIG_NAME = "config.json"
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
+def is_torch_available():
+    return _torch_available
+
+def is_tf_available():
+    return _tf_available
+
 if not six.PY2:
     def add_start_docstrings(*docstr):
         def docstring_decorator(fn):
...
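The two helpers above give the rest of the library a single place to ask whether a framework is importable, without every module repeating the try/except import dance. A minimal sketch of the pattern from a caller's side (the to_framework_tensor helper is hypothetical and not part of this commit; only is_tf_available/is_torch_available come from the diff, imported here with the same relative style the modules below use):

    from .file_utils import is_tf_available, is_torch_available

    def to_framework_tensor(ids):
        # Hypothetical helper: wrap a list of token ids in whichever
        # framework's tensor type is installed, falling back to the list.
        if is_torch_available():
            import torch
            return torch.tensor(ids)
        if is_tf_available():
            import tensorflow as tf
            return tf.constant(ids)
        return ids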
@@ -23,7 +23,10 @@ import six
 import copy
 from io import open
 
-from .file_utils import cached_path
+from .file_utils import cached_path, is_tf_available
+
+if is_tf_available():
+    import tensorflow as tf
 
 logger = logging.getLogger(__name__)
@@ -686,19 +689,32 @@ class PreTrainedTokenizer(object):
                 to their model.
             **kwargs: passed to the `self.tokenize()` method
         """
+        if is_tf_available():
+            is_tf_tensor = False
+            if isinstance(text, tf.Tensor):
+                text = text.numpy()
+                is_tf_tensor = True
+            if isinstance(text, bytes):
+                text = text.decode('utf-8')
+
         if text_pair is None:
             if add_special_tokens:
-                return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
+                output = self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
             else:
-                return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
-
-        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
-        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
-
-        if add_special_tokens:
-            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
+                output = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
         else:
-            return first_sentence_tokens, second_sentence_tokens
+            first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
+            second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
+
+            if add_special_tokens:
+                output = self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
+            else:
+                output = first_sentence_tokens, second_sentence_tokens
+
+        if is_tf_available() and is_tf_tensor:
+            output = tf.constant(output)
+
+        return output
 
     def add_special_tokens_single_sentence(self, token_ids):
         logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
...
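After this change, encode round-trips TensorFlow tensors: a tf.Tensor input is converted via .numpy() and decoded to a Python string before tokenization, and the resulting ids are wrapped back into a tensor with tf.constant. A minimal usage sketch, assuming TensorFlow 2.x is installed (the tokenizer class and checkpoint name are illustrative, not taken from this commit):

    import tensorflow as tf
    from pytorch_transformers import BertTokenizer  # illustrative tokenizer choice

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # A scalar string tensor, e.g. one element from a tf.data pipeline
    text = tf.constant('Hello, world!')

    # encode() detects the tf.Tensor, calls .numpy(), decodes the bytes,
    # tokenizes as before, and returns the ids wrapped in a tf.Tensor
    ids = tokenizer.encode(text, add_special_tokens=True)

Note that the output is only re-wrapped when the input was itself a tf.Tensor; plain string inputs keep returning Python lists as before.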