Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c4acc3a8
Commit
c4acc3a8
authored
Sep 25, 2019
by
thomwolf
Browse files
let encode accept tensor inputs
parent
e8e956db
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
16 deletions
+47
-16
pytorch_transformers/__init__.py
pytorch_transformers/__init__.py
+2
-7
pytorch_transformers/file_utils.py
pytorch_transformers/file_utils.py
+20
-0
pytorch_transformers/tokenization_utils.py
pytorch_transformers/tokenization_utils.py
+25
-9
No files found.
pytorch_transformers/__init__.py
View file @
c4acc3a8
...
...
@@ -163,10 +163,5 @@ if _tf_available and _torch_available:
# Files and general utilities
from
.file_utils
import
(
PYTORCH_TRANSFORMERS_CACHE
,
PYTORCH_PRETRAINED_BERT_CACHE
,
cached_path
,
add_start_docstrings
,
add_end_docstrings
,
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
CONFIG_NAME
)
def
is_torch_available
():
return
_torch_available
def
is_tf_available
():
return
_tf_available
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
CONFIG_NAME
,
is_tf_available
,
is_torch_available
)
\ No newline at end of file
pytorch_transformers/file_utils.py
View file @
c4acc3a8
...
...
@@ -23,6 +23,20 @@ from botocore.exceptions import ClientError
import
requests
from
tqdm
import
tqdm
try
:
import
tensorflow
as
tf
assert
int
(
tf
.
__version__
[
0
])
>=
2
_tf_available
=
True
# pylint: disable=invalid-name
except
(
ImportError
,
AssertionError
):
_tf_available
=
False
# pylint: disable=invalid-name
try
:
import
torch
_torch_available
=
True
# pylint: disable=invalid-name
except
ImportError
:
_torch_available
=
False
# pylint: disable=invalid-name
try
:
from
torch.hub
import
_get_torch_home
torch_cache_home
=
_get_torch_home
()
...
...
@@ -55,6 +69,12 @@ CONFIG_NAME = "config.json"
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
def
is_torch_available
():
return
_torch_available
def
is_tf_available
():
return
_tf_available
if
not
six
.
PY2
:
def
add_start_docstrings
(
*
docstr
):
def
docstring_decorator
(
fn
):
...
...
pytorch_transformers/tokenization_utils.py
View file @
c4acc3a8
...
...
@@ -23,7 +23,10 @@ import six
import
copy
from
io
import
open
from
.file_utils
import
cached_path
from
.file_utils
import
cached_path
,
is_tf_available
if
is_tf_available
():
import
tensorflow
as
tf
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -686,19 +689,32 @@ class PreTrainedTokenizer(object):
to their model.
**kwargs: passed to the `self.tokenize()` method
"""
if
is_tf_available
():
is_tf_tensor
=
False
if
isinstance
(
text
,
tf
.
Tensor
):
text
=
text
.
numpy
()
is_tf_tensor
=
True
if
isinstance
(
text
,
bytes
):
text
=
text
.
decode
(
'utf-8'
)
if
text_pair
is
None
:
if
add_special_tokens
:
return
self
.
add_special_tokens_single_sentence
(
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
,
**
kwargs
)))
output
=
self
.
add_special_tokens_single_sentence
(
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
,
**
kwargs
)))
else
:
output
=
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
,
**
kwargs
))
else
:
return
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
,
**
kwargs
))
first_sentence_tokens
=
[
self
.
_convert_token_to_id
(
token
)
for
token
in
self
.
tokenize
(
text
,
**
kwargs
)]
second_sentence_tokens
=
[
self
.
_convert_token_to_id
(
token
)
for
token
in
self
.
tokenize
(
text_pair
,
**
kwargs
)]
if
add_special_tokens
:
return
self
.
add_special_tokens_sentences_pair
(
first_sentence_tokens
,
second_sentence_tokens
)
output
=
self
.
add_special_tokens_sentences_pair
(
first_sentence_tokens
,
second_sentence_tokens
)
else
:
return
first_sentence_tokens
,
second_sentence_tokens
output
=
first_sentence_tokens
,
second_sentence_tokens
if
is_tf_available
()
and
is_tf_tensor
:
output
=
tf
.
constant
(
output
)
return
output
def
add_special_tokens_single_sentence
(
self
,
token_ids
):
logger
.
warning
(
"This tokenizer does not make use of special tokens. The sequence has been returned with no modification."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment