chenpangpang / transformers

Commit 280db79a (Unverified)
Authored Jan 14, 2021 by Lysandre Debut, committed via GitHub on Jan 14, 2021
Parent: 8bf27075

BatchEncoding.to with device with tests (#9584)
Changes: 2 changed files, 11 additions and 1 deletion

    src/transformers/tokenization_utils_base.py    +7 / -1
    tests/test_tokenization_common.py               +4 / -0
src/transformers/tokenization_utils_base.py

@@ -65,6 +65,12 @@ def _is_torch(x):
     return isinstance(x, torch.Tensor)
 
 
+def _is_torch_device(x):
+    import torch
+
+    return isinstance(x, torch.device)
+
+
 def _is_tensorflow(x):
     import tensorflow as tf
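Not part of the diff: a quick sanity sketch of what the new helper accepts and rejects, assuming it is imported directly from its (private) module as shown.

    import torch
    from transformers.tokenization_utils_base import _is_torch_device

    # torch is imported lazily inside the helper, so the module itself never
    # needs torch at import time; the check only runs when it is called.
    assert _is_torch_device(torch.device("cpu"))
    assert not _is_torch_device("cpu")  # strings go through the isinstance(device, str) branch
    assert not _is_torch_device(0)      # ints (CUDA indices) go through the isinstance(device, int) branch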
@@ -801,7 +807,7 @@ class BatchEncoding(UserDict):
         # This check catches things like APEX blindly calling "to" on all inputs to a module
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
         # into a HalfTensor
-        if isinstance(device, str) or isinstance(device, torch.device) or isinstance(device, int):
+        if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
             self.data = {k: v.to(device=device) for k, v in self.data.items()}
         else:
             logger.warning(
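The practical effect of the change above: BatchEncoding.to() now recognizes an actual torch.device, not just a string or int. The new helper imports torch lazily, whereas the old inline isinstance(device, torch.device) check relied on a torch name that is not guaranteed to be in scope in this module. A minimal usage sketch; the checkpoint name is an arbitrary example, not something this commit depends on:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any checkpoint works
    encoding = tokenizer("hello world", return_tensors="pt")

    # All three device forms are accepted by the updated check:
    encoding = encoding.to("cpu")                # str
    encoding = encoding.to(torch.device("cpu"))  # torch.device, the case fixed here
    if torch.cuda.is_available():
        encoding = encoding.to(0)                # int CUDA index

    print(encoding["input_ids"].device)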
tests/test_tokenization_common.py

@@ -1704,6 +1704,10 @@ class TokenizerTesterMixin:
                 first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
                 sequence = " ".join(first_ten_tokens)
                 encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
+
+                # Ensure that the BatchEncoding.to() method works.
+                encoded_sequence.to(model.device)
+
                 batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
                 # This should not fail
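A rough standalone equivalent of what the added test lines exercise, outside the TokenizerTesterMixin harness; the model and tokenizer names are arbitrary illustrations, not part of the diff:

    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint only
    model = AutoModel.from_pretrained("bert-base-uncased")

    encoded_sequence = tokenizer.encode_plus("a short test sequence", return_tensors="pt")

    # Mirrors the new test line: model.device is a torch.device, which now goes
    # through the _is_torch_device() branch of BatchEncoding.to().
    encoded_sequence = encoded_sequence.to(model.device)

    with torch.no_grad():
        outputs = model(**encoded_sequence)
    print(outputs.last_hidden_state.shape)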