Commit c4acc3a8 authored by thomwolf

let encode accept tensor inputs

parent e8e956db
@@ -163,10 +163,5 @@ if _tf_available and _torch_available:
 # Files and general utilities
 from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
                          cached_path, add_start_docstrings, add_end_docstrings,
-                         WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
-
-def is_torch_available():
-    return _torch_available
-
-def is_tf_available():
-    return _tf_available
+                         WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
+                         is_tf_available, is_torch_available)
\ No newline at end of file
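Usage sketch (not part of the commit): with the helpers now defined in file_utils and re-exported at the package root, downstream code can import them directly. The top-level package name `transformers` is an assumption here; adjust it to the import name used by your checkout.

    # Illustrative only; the package import name is assumed, not prescribed by this diff.
    from transformers import is_tf_available, is_torch_available

    if is_torch_available():
        import torch
        print("PyTorch available:", torch.__version__)
    if is_tf_available():
        import tensorflow as tf  # only reported available for TF >= 2.0
        print("TensorFlow 2.x available:", tf.__version__)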
@@ -23,6 +23,20 @@ from botocore.exceptions import ClientError
 import requests
 from tqdm import tqdm
 
+try:
+    import tensorflow as tf
+    assert int(tf.__version__[0]) >= 2
+    _tf_available = True  # pylint: disable=invalid-name
+except (ImportError, AssertionError):
+    _tf_available = False  # pylint: disable=invalid-name
+
+try:
+    import torch
+    _torch_available = True  # pylint: disable=invalid-name
+except ImportError:
+    _torch_available = False  # pylint: disable=invalid-name
+
 try:
     from torch.hub import _get_torch_home
     torch_cache_home = _get_torch_home()
@@ -55,6 +69,12 @@ CONFIG_NAME = "config.json"
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
+def is_torch_available():
+    return _torch_available
+
+def is_tf_available():
+    return _tf_available
+
 if not six.PY2:
     def add_start_docstrings(*docstr):
         def docstring_decorator(fn):
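Note that `is_tf_available()` only returns True when a TensorFlow 2.x install imports successfully, because of the `assert int(tf.__version__[0]) >= 2` check above; a TF 1.x install counts as unavailable. A minimal sketch of the same guard pattern (the module path `transformers.file_utils` and the helper `to_framework_tensor` are assumptions for illustration, not part of this commit):

    from transformers.file_utils import is_tf_available, is_torch_available

    def to_framework_tensor(token_ids):
        """Wrap a list of ids in a framework tensor when one is installed (hypothetical helper)."""
        if is_tf_available():
            import tensorflow as tf
            return tf.constant(token_ids)
        if is_torch_available():
            import torch
            return torch.tensor(token_ids)
        return token_ids  # plain Python list fallback

    print(to_framework_tensor([101, 7592, 102]))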
@@ -23,7 +23,10 @@ import six
 import copy
 from io import open
 
-from .file_utils import cached_path
+from .file_utils import cached_path, is_tf_available
+
+if is_tf_available():
+    import tensorflow as tf
 
 logger = logging.getLogger(__name__)
@@ -686,19 +689,32 @@ class PreTrainedTokenizer(object):
                 to their model.
             **kwargs: passed to the `self.tokenize()` method
         """
+
+        if is_tf_available():
+            is_tf_tensor = False
+            if isinstance(text, tf.Tensor):
+                text = text.numpy()
+                is_tf_tensor = True
+            if isinstance(text, bytes):
+                text = text.decode('utf-8')
+
         if text_pair is None:
             if add_special_tokens:
-                return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
+                output = self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
             else:
-                return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
-
-        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
-        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
-
-        if add_special_tokens:
-            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
-        else:
-            return first_sentence_tokens, second_sentence_tokens
+                output = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+        else:
+            first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
+            second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
+            if add_special_tokens:
+                output = self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
+            else:
+                output = first_sentence_tokens, second_sentence_tokens
+
+        if is_tf_available() and is_tf_tensor:
+            output = tf.constant(output)
+
+        return output
 
     def add_special_tokens_single_sentence(self, token_ids):
         logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
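Usage sketch of the new behavior (not part of the diff): in TF 2.x eager mode, `encode()` can now be fed a scalar string tensor and, because of the trailing `tf.constant(output)` branch, hands back a tf.Tensor of token ids. `BertTokenizer`, the `bert-base-uncased` checkpoint, and the top-level `transformers` import name are assumptions used only for illustration.

    # Assumes TF 2.x eager execution (needed for text.numpy()) and an installed checkpoint.
    import tensorflow as tf
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    text = tf.constant("let encode accept tensor inputs")  # scalar tf.string tensor
    ids = tokenizer.encode(text, add_special_tokens=True)  # decoded to str inside encode()
    print(ids)  # tf.Tensor of int token ids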