Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c4acc3a8
Commit
c4acc3a8
authored
Sep 25, 2019
by
thomwolf
Browse files
let encode accept tensor inputs
parent
e8e956db
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
16 deletions
+47
-16
pytorch_transformers/__init__.py
pytorch_transformers/__init__.py
+2
-7
pytorch_transformers/file_utils.py
pytorch_transformers/file_utils.py
+20
-0
pytorch_transformers/tokenization_utils.py
pytorch_transformers/tokenization_utils.py
+25
-9
No files found.
pytorch_transformers/__init__.py
View file @
c4acc3a8
...
...
@@ -163,10 +163,5 @@ if _tf_available and _torch_available:
# Files and general utilities
from
.file_utils
import
(
PYTORCH_TRANSFORMERS_CACHE
,
PYTORCH_PRETRAINED_BERT_CACHE
,
cached_path
,
add_start_docstrings
,
add_end_docstrings
,
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
CONFIG_NAME
)
def
is_torch_available
():
return
_torch_available
def
is_tf_available
():
return
_tf_available
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
CONFIG_NAME
,
is_tf_available
,
is_torch_available
)
\ No newline at end of file
pytorch_transformers/file_utils.py
View file @
c4acc3a8
...
...
@@ -23,6 +23,20 @@ from botocore.exceptions import ClientError
import
requests
from
tqdm
import
tqdm
try
:
import
tensorflow
as
tf
assert
int
(
tf
.
__version__
[
0
])
>=
2
_tf_available
=
True
# pylint: disable=invalid-name
except
(
ImportError
,
AssertionError
):
_tf_available
=
False
# pylint: disable=invalid-name
try
:
import
torch
_torch_available
=
True
# pylint: disable=invalid-name
except
ImportError
:
_torch_available
=
False
# pylint: disable=invalid-name
try
:
from
torch.hub
import
_get_torch_home
torch_cache_home
=
_get_torch_home
()
...
...
@@ -55,6 +69,12 @@ CONFIG_NAME = "config.json"
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
def
is_torch_available
():
return
_torch_available
def
is_tf_available
():
return
_tf_available
if
not
six
.
PY2
:
def
add_start_docstrings
(
*
docstr
):
def
docstring_decorator
(
fn
):
...
...
pytorch_transformers/tokenization_utils.py
View file @
c4acc3a8
...
...
@@ -23,7 +23,10 @@ import six
import
copy
from
io
import
open
from
.file_utils
import
cached_path
from
.file_utils
import
cached_path
,
is_tf_available
if
is_tf_available
():
import
tensorflow
as
tf
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -686,19 +689,32 @@ class PreTrainedTokenizer(object):
to their model.
**kwargs: passed to the `self.tokenize()` method
"""
if
is_tf_available
():
is_tf_tensor
=
False
if
isinstance
(
text
,
tf
.
Tensor
):
text
=
text
.
numpy
()
is_tf_tensor
=
True
if
isinstance
(
text
,
bytes
):
text
=
text
.
decode
(
'utf-8'
)
if
text_pair
is
None
:
if
add_special_tokens
:
return
self
.
add_special_tokens_single_sentence
(
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
,
**
kwargs
)))
output
=
self
.
add_special_tokens_single_sentence
(
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
,
**
kwargs
)))
else
:
output
=
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
,
**
kwargs
))
else
:
return
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
,
**
kwargs
))
first_sentence_tokens
=
[
self
.
_convert_token_to_id
(
token
)
for
token
in
self
.
tokenize
(
text
,
**
kwargs
)]
second_sentence_tokens
=
[
self
.
_convert_token_to_id
(
token
)
for
token
in
self
.
tokenize
(
text_pair
,
**
kwargs
)]
if
add_special_tokens
:
return
self
.
add_special_tokens_sentences_pair
(
first_sentence_tokens
,
second_sentence_tokens
)
output
=
self
.
add_special_tokens_sentences_pair
(
first_sentence_tokens
,
second_sentence_tokens
)
else
:
return
first_sentence_tokens
,
second_sentence_tokens
output
=
first_sentence_tokens
,
second_sentence_tokens
if
is_tf_available
()
and
is_tf_tensor
:
output
=
tf
.
constant
(
output
)
return
output
def
add_special_tokens_single_sentence
(
self
,
token_ids
):
logger
.
warning
(
"This tokenizer does not make use of special tokens. The sequence has been returned with no modification."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment