"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "839bfaedb21e42edee093b9e21e2c2f1ea7514f0"
Commit 78863f6b authored by thomwolf
Browse files

fix tokenizer to tensors

parent 8a618e0a
...@@ -27,6 +27,8 @@ from .file_utils import cached_path, is_tf_available, is_torch_available ...@@ -27,6 +27,8 @@ from .file_utils import cached_path, is_tf_available, is_torch_available
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
if is_torch_available()
import torch
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -849,7 +851,11 @@ class PreTrainedTokenizer(object): ...@@ -849,7 +851,11 @@ class PreTrainedTokenizer(object):
if return_tensors == 'tf' and is_tf_available(): if return_tensors == 'tf' and is_tf_available():
sequence = tf.constant(sequence) sequence = tf.constant(sequence)
token_type_ids = tf.constant(token_type_ids) token_type_ids = tf.constant(token_type_ids)
elif return_tensors = 'pt' and is elif return_tensors == 'pt' and is_torch_available():
sequence = torch.tensor(sequence)
token_type_ids = torch.tensor(token_type_ids)
elif return_tensors is not None:
logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))
encoded_inputs["input_ids"] = sequence encoded_inputs["input_ids"] = sequence
encoded_inputs["token_type_ids"] = token_type_ids encoded_inputs["token_type_ids"] = token_type_ids
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment